github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/internal/memoize/memoize.go (about)

     1  // Copyright 2019 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Package memoize defines a "promise" abstraction that enables
     6  // memoization of the result of calling an expensive but idempotent
     7  // function.
     8  //
// Call p = NewPromise(debug, f) to obtain a promise for the future
// result of calling f, and call p.Get(ctx, arg) to obtain that result.
// All successful calls to p.Get return the result of a single call of
// f. Get blocks if the function has not finished (or started).
    13  //
// A Store is a map of arbitrary keys to promises. Use Store.Promise
// to create a promise in the store. All calls to Store.Promise with
// the same key return the same promise as long as it is in the store.
// These promises are reference-counted and must be explicitly released.
// Once the last reference is released, the promise is removed from the
// store, unless the store's EvictionPolicy is NeverEvict.
    19  package memoize
    20  
    21  import (
    22  	"context"
    23  	"fmt"
    24  	"reflect"
    25  	"runtime/trace"
    26  	"sync"
    27  	"sync/atomic"
    28  
    29  	"golang.org/x/tools/internal/xcontext"
    30  )
    31  
// Function is the type of a function that can be memoized.
//
// The ctx passed to a Function is not any single caller's context: it
// is cancelled only once every concurrent Promise.Get waiting for the
// result has been cancelled, and the Function should return promptly
// when that happens. A result returned after ctx has been cancelled is
// discarded.
//
// If the arg is a RefCounted, its Acquire method is called before the
// Function is invoked, and the release function returned by Acquire is
// called after the Function returns (or without the Function being
// called at all, if the computation was cancelled before it started).
//
// The argument must not materially affect the result of the function
// in ways that are not captured by the promise's key, since if
// Promise.Get is called twice concurrently, with the same (implicit)
// key but different arguments, the Function is called only once but
// its result must be suitable for both callers.
//
// The main purpose of the argument is to avoid the Function closure
// needing to retain large objects (in practice: the snapshot) in
// memory that can be supplied at call time by any caller.
//
// A Function that returns nil cannot be distinguished by Promise.Cached
// from one that has not finished, and its result is skipped by
// Store.DebugOnlyIterate.
type Function func(ctx context.Context, arg interface{}) interface{}
    46  
// A RefCounted is a value whose functional lifetime is determined by
// reference counting.
//
// When a RefCounted is passed as the arg of a Promise.Get call that
// starts a new computation, its Acquire method is called before the
// Function is invoked, and the function returned by Acquire is called
// when the computation's goroutine finishes. Usually both events happen
// within a single call to Get, so Get would be fine with a "borrowed"
// reference, but if the context is cancelled, Get may return before the
// Function is complete, causing the argument to escape, and potential
// premature destruction of the value. For a reference-counted type, this
// requires a pair of increment/decrement operations to extend its life.
//
// Get calls that merely join a computation already in progress do not
// acquire their arg.
type RefCounted interface {
	// Acquire prevents the value from being destroyed until the
	// returned function is called. This package calls the returned
	// function exactly once per Acquire.
	Acquire() func()
}
    63  
// A Promise represents the future result of a call to a function.
//
// A Promise contains a sync.Mutex and must not be copied; create one
// with NewPromise or obtain a shared one from Store.Promise.
type Promise struct {
	debug string // for observability; labels the trace region of each run

	// refcount is the reference count in the containing Store, used by
	// Store.Promise. It is guarded by Store.promisesMu on the containing Store.
	refcount int32

	// mu guards every field declared below it.
	mu sync.Mutex

	// A Promise starts out IDLE, waiting for something to demand
	// its evaluation. It then transitions into RUNNING state.
	//
	// While RUNNING, waiters tracks the number of Get calls
	// waiting for a result, and the done channel is used to
	// notify waiters of the next state transition. Once
	// evaluation finishes, value is set, state changes to
	// COMPLETED, and done is closed, unblocking waiters.
	//
	// Alternatively, as Get calls are cancelled, they decrement
	// waiters. If it drops to zero, the inner context is
	// cancelled, computation is abandoned, and state resets to
	// IDLE to start the process over again.
	state state
	// done is set in running state, and closed when exiting it (on
	// completion, or when the last waiter gives up). Each waiter keeps
	// its own copy because p.done is reset to nil when a run is abandoned.
	done chan struct{}
	// cancel is set in running state. It cancels computation.
	cancel context.CancelFunc
	// waiters is the number of Gets waiting on the current run. Only
	// waiters whose context is cancelled decrement it; once the promise
	// has completed, the count is no longer consulted.
	waiters uint
	// the function that will be used to populate the value; it is set to
	// nil once the value has been stored, so it can be garbage collected.
	function Function
	// value is set in completed state.
	value interface{}
}
    99  
   100  // NewPromise returns a promise for the future result of calling the
   101  // specified function.
   102  //
   103  // The debug string is used to classify promises in logs and metrics.
   104  // It should be drawn from a small set.
   105  func NewPromise(debug string, function Function) *Promise {
   106  	if function == nil {
   107  		panic("nil function")
   108  	}
   109  	return &Promise{
   110  		debug:    debug,
   111  		function: function,
   112  	}
   113  }
   114  
   115  type state int
   116  
   117  const (
   118  	stateIdle      = iota // newly constructed, or last waiter was cancelled
   119  	stateRunning          // start was called and not cancelled
   120  	stateCompleted        // function call ran to completion
   121  )
   122  
   123  // Cached returns the value associated with a promise.
   124  //
   125  // It will never cause the value to be generated.
   126  // It will return the cached value, if present.
   127  func (p *Promise) Cached() interface{} {
   128  	p.mu.Lock()
   129  	defer p.mu.Unlock()
   130  	if p.state == stateCompleted {
   131  		return p.value
   132  	}
   133  	return nil
   134  }
   135  
   136  // Get returns the value associated with a promise.
   137  //
   138  // All calls to Promise.Get on a given promise return the
   139  // same result but the function is called (to completion) at most once.
   140  //
   141  // If the value is not yet ready, the underlying function will be invoked.
   142  //
   143  // If ctx is cancelled, Get returns (nil, Canceled).
   144  // If all concurrent calls to Get are cancelled, the context provided
   145  // to the function is cancelled. A later call to Get may attempt to
   146  // call the function again.
   147  func (p *Promise) Get(ctx context.Context, arg interface{}) (interface{}, error) {
   148  	if ctx.Err() != nil {
   149  		return nil, ctx.Err()
   150  	}
   151  	p.mu.Lock()
   152  	switch p.state {
   153  	case stateIdle:
   154  		return p.run(ctx, arg)
   155  	case stateRunning:
   156  		return p.wait(ctx)
   157  	case stateCompleted:
   158  		defer p.mu.Unlock()
   159  		return p.value, nil
   160  	default:
   161  		panic("unknown state")
   162  	}
   163  }
   164  
   165  // run starts p.function and returns the result. p.mu must be locked.
   166  func (p *Promise) run(ctx context.Context, arg interface{}) (interface{}, error) {
   167  	childCtx, cancel := context.WithCancel(xcontext.Detach(ctx))
   168  	p.cancel = cancel
   169  	p.state = stateRunning
   170  	p.done = make(chan struct{})
   171  	function := p.function // Read under the lock
   172  
   173  	// Make sure that the argument isn't destroyed while we're running in it.
   174  	release := func() {}
   175  	if rc, ok := arg.(RefCounted); ok {
   176  		release = rc.Acquire()
   177  	}
   178  
   179  	go func() {
   180  		trace.WithRegion(childCtx, fmt.Sprintf("Promise.run %s", p.debug), func() {
   181  			defer release()
   182  			// Just in case the function does something expensive without checking
   183  			// the context, double-check we're still alive.
   184  			if childCtx.Err() != nil {
   185  				return
   186  			}
   187  			v := function(childCtx, arg)
   188  			if childCtx.Err() != nil {
   189  				return
   190  			}
   191  
   192  			p.mu.Lock()
   193  			defer p.mu.Unlock()
   194  			// It's theoretically possible that the promise has been cancelled out
   195  			// of the run that started us, and then started running again since we
   196  			// checked childCtx above. Even so, that should be harmless, since each
   197  			// run should produce the same results.
   198  			if p.state != stateRunning {
   199  				return
   200  			}
   201  
   202  			p.value = v
   203  			p.function = nil // aid GC
   204  			p.state = stateCompleted
   205  			close(p.done)
   206  		})
   207  	}()
   208  
   209  	return p.wait(ctx)
   210  }
   211  
   212  // wait waits for the value to be computed, or ctx to be cancelled. p.mu must be locked.
   213  func (p *Promise) wait(ctx context.Context) (interface{}, error) {
   214  	p.waiters++
   215  	done := p.done
   216  	p.mu.Unlock()
   217  
   218  	select {
   219  	case <-done:
   220  		p.mu.Lock()
   221  		defer p.mu.Unlock()
   222  		if p.state == stateCompleted {
   223  			return p.value, nil
   224  		}
   225  		return nil, nil
   226  	case <-ctx.Done():
   227  		p.mu.Lock()
   228  		defer p.mu.Unlock()
   229  		p.waiters--
   230  		if p.waiters == 0 && p.state == stateRunning {
   231  			p.cancel()
   232  			close(p.done)
   233  			p.state = stateIdle
   234  			p.done = nil
   235  			p.cancel = nil
   236  		}
   237  		return nil, ctx.Err()
   238  	}
   239  }
   240  
// An EvictionPolicy controls the eviction behavior of keys in a Store when
// they no longer have any references.
//
// The policy is consulted whenever a release function returned by
// Store.Promise drops a promise's reference count to zero.
type EvictionPolicy int

const (
	// ImmediatelyEvict evicts keys as soon as they no longer have references.
	// It is the zero value, and therefore the policy of a zero Store.
	ImmediatelyEvict EvictionPolicy = iota

	// NeverEvict does not evict keys. A promise, together with any value it
	// has computed, stays in the Store after its last reference is released.
	// A later Store.Promise call with the same key reuses it.
	NeverEvict
)
   252  
// A Store maps arbitrary keys to reference-counted promises.
//
// The zero value is a valid Store, though a store may also be created via
// NewStore if a custom EvictionPolicy is required. A Store contains a
// sync.Mutex and must not be copied after first use.
type Store struct {
	// evictionPolicy is applied when a promise's last reference is released.
	evictionPolicy EvictionPolicy

	// promisesMu guards promises and the refcount field of every Promise
	// in it. promises is allocated lazily by Store.Promise.
	promisesMu sync.Mutex
	promises   map[interface{}]*Promise
}
   263  
   264  // NewStore creates a new store with the given eviction policy.
   265  func NewStore(policy EvictionPolicy) *Store {
   266  	return &Store{evictionPolicy: policy}
   267  }
   268  
   269  // Promise returns a reference-counted promise for the future result of
   270  // calling the specified function.
   271  //
   272  // Calls to Promise with the same key return the same promise, incrementing its
   273  // reference count.  The caller must call the returned function to decrement
   274  // the promise's reference count when it is no longer needed. The returned
   275  // function must not be called more than once.
   276  //
   277  // Once the last reference has been released, the promise is removed from the
   278  // store.
   279  func (store *Store) Promise(key interface{}, function Function) (*Promise, func()) {
   280  	store.promisesMu.Lock()
   281  	p, ok := store.promises[key]
   282  	if !ok {
   283  		p = NewPromise(reflect.TypeOf(key).String(), function)
   284  		if store.promises == nil {
   285  			store.promises = map[interface{}]*Promise{}
   286  		}
   287  		store.promises[key] = p
   288  	}
   289  	p.refcount++
   290  	store.promisesMu.Unlock()
   291  
   292  	var released int32
   293  	release := func() {
   294  		if !atomic.CompareAndSwapInt32(&released, 0, 1) {
   295  			panic("release called more than once")
   296  		}
   297  		store.promisesMu.Lock()
   298  
   299  		p.refcount--
   300  		if p.refcount == 0 && store.evictionPolicy != NeverEvict {
   301  			// Inv: if p.refcount > 0, then store.promises[key] == p.
   302  			delete(store.promises, key)
   303  		}
   304  		store.promisesMu.Unlock()
   305  	}
   306  
   307  	return p, release
   308  }
   309  
   310  // Stats returns the number of each type of key in the store.
   311  func (s *Store) Stats() map[reflect.Type]int {
   312  	result := map[reflect.Type]int{}
   313  
   314  	s.promisesMu.Lock()
   315  	defer s.promisesMu.Unlock()
   316  
   317  	for k := range s.promises {
   318  		result[reflect.TypeOf(k)]++
   319  	}
   320  	return result
   321  }
   322  
   323  // DebugOnlyIterate iterates through the store and, for each completed
   324  // promise, calls f(k, v) for the map key k and function result v.  It
   325  // should only be used for debugging purposes.
   326  func (s *Store) DebugOnlyIterate(f func(k, v interface{})) {
   327  	s.promisesMu.Lock()
   328  	defer s.promisesMu.Unlock()
   329  
   330  	for k, p := range s.promises {
   331  		if v := p.Cached(); v != nil {
   332  			f(k, v)
   333  		}
   334  	}
   335  }