// github.com/cockroachdb/tools@v0.0.0-20230222021103-a6d27438930d/internal/memoize/memoize.go

// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.

// Package memoize defines a "promise" abstraction that enables
// memoization of the result of calling an expensive but idempotent
// function.
//
// Call p = NewPromise(debug, f) to obtain a promise for the future
// result of calling f, and call p.Get(ctx, arg) to obtain that result.
// All calls to p.Get return the result of a single completed call of f.
// Get blocks if the function has not finished (or started).
//
// A Store is a map of arbitrary keys to promises. Use Store.Promise
// to create a promise in the store. All calls to Store.Promise(k, f)
// return the same promise as long as it is in the store. These promises
// are reference-counted and must be explicitly released. Once the last
// reference is released, the promise is removed from the store (unless
// the store was created with the NeverEvict policy).
package memoize

import (
	"context"
	"fmt"
	"reflect"
	"runtime/trace"
	"sync"
	"sync/atomic"

	"golang.org/x/tools/internal/xcontext"
)

// Function is the type of a function that can be memoized.
//
// If the arg is a RefCounted, its Acquire method is called before the
// function is started, and the release function that Acquire returned
// is called when the function returns.
//
// The argument must not materially affect the result of the function
// in ways that are not captured by the promise's key, since if
// Promise.Get is called twice concurrently, with the same (implicit)
// key but different arguments, the Function is called only once but
// its result must be suitable for both callers.
//
// The main purpose of the argument is to avoid the Function closure
// needing to retain large objects (in practice: the snapshot) in
// memory that can be supplied at call time by any caller.
type Function func(ctx context.Context, arg interface{}) interface{}

// A RefCounted is a value whose functional lifetime is determined by
// reference counting.
//
// Its Acquire method is called before the Function is invoked, and
// the corresponding release is called when the Function returns.
// Usually both events happen within a single call to Get, so Get
// would be fine with a "borrowed" reference, but if the context is
// cancelled, Get may return before the Function is complete, causing
// the argument to escape, and potential premature destruction of the
// value. For a reference-counted type, this requires a pair of
// increment/decrement operations to extend its life.
type RefCounted interface {
	// Acquire prevents the value from being destroyed until the
	// returned function is called.
	Acquire() func()
}

// A Promise represents the future result of a call to a function.
type Promise struct {
	debug string // for observability

	// refcount is the reference count in the containing Store, used by
	// Store.Promise. It is guarded by Store.promisesMu on the containing Store.
	refcount int32

	// mu guards all of the fields declared below it.
	mu sync.Mutex

	// A Promise starts out IDLE, waiting for something to demand
	// its evaluation. It then transitions into RUNNING state.
	//
	// While RUNNING, waiters tracks the number of Get calls
	// waiting for a result, and the done channel is used to
	// notify waiters of the next state transition. Once
	// evaluation finishes, value is set, state changes to
	// COMPLETED, and done is closed, unblocking waiters.
	//
	// Alternatively, as Get calls are cancelled, they decrement
	// waiters. If it drops to zero, the inner context is
	// cancelled, computation is abandoned, and state resets to
	// IDLE to start the process over again.
	state state
	// done is set in running state, and closed when exiting it.
	done chan struct{}
	// cancel is set in running state. It cancels computation.
	cancel context.CancelFunc
	// waiters is the number of Gets outstanding.
	waiters uint
	// function is used to populate the value. It is cleared once the
	// promise completes so that the closure can be garbage collected.
	function Function
	// value is set in completed state.
	value interface{}
}

// NewPromise returns a promise for the future result of calling the
// specified function.
//
// The debug string is used to classify promises in logs and metrics.
// It should be drawn from a small set.
//
// NewPromise panics if function is nil.
func NewPromise(debug string, function Function) *Promise {
	if function == nil {
		panic("nil function")
	}
	return &Promise{
		debug:    debug,
		function: function,
	}
}

// state is the evaluation state of a Promise.
type state int

// The constants are explicitly typed so that the compiler rejects
// accidental mixing of promise states with unrelated integers.
// stateIdle is deliberately the zero value: a freshly allocated
// Promise is idle.
const (
	stateIdle      state = iota // newly constructed, or last waiter was cancelled
	stateRunning                // start was called and not cancelled
	stateCompleted              // function call ran to completion
)

// Cached returns the value associated with a promise.
//
// It will never cause the value to be generated.
// It returns the cached value if the promise has completed, and nil
// otherwise. A completed promise whose value is nil is therefore
// indistinguishable from one that has not completed.
func (p *Promise) Cached() interface{} {
	p.mu.Lock()
	defer p.mu.Unlock()
	if p.state == stateCompleted {
		return p.value
	}
	return nil
}

// Get returns the value associated with a promise.
//
// All calls to Promise.Get on a given promise return the
// same result but the function is called (to completion) at most once.
//
// If the value is not yet ready, the underlying function will be invoked.
//
// If ctx is cancelled, Get returns (nil, ctx.Err()).
// If all concurrent calls to Get are cancelled, the context provided
// to the function is cancelled. A later call to Get may attempt to
// call the function again.
147 func (p *Promise) Get(ctx context.Context, arg interface{}) (interface{}, error) { 148 if ctx.Err() != nil { 149 return nil, ctx.Err() 150 } 151 p.mu.Lock() 152 switch p.state { 153 case stateIdle: 154 return p.run(ctx, arg) 155 case stateRunning: 156 return p.wait(ctx) 157 case stateCompleted: 158 defer p.mu.Unlock() 159 return p.value, nil 160 default: 161 panic("unknown state") 162 } 163 } 164 165 // run starts p.function and returns the result. p.mu must be locked. 166 func (p *Promise) run(ctx context.Context, arg interface{}) (interface{}, error) { 167 childCtx, cancel := context.WithCancel(xcontext.Detach(ctx)) 168 p.cancel = cancel 169 p.state = stateRunning 170 p.done = make(chan struct{}) 171 function := p.function // Read under the lock 172 173 // Make sure that the argument isn't destroyed while we're running in it. 174 release := func() {} 175 if rc, ok := arg.(RefCounted); ok { 176 release = rc.Acquire() 177 } 178 179 go func() { 180 trace.WithRegion(childCtx, fmt.Sprintf("Promise.run %s", p.debug), func() { 181 defer release() 182 // Just in case the function does something expensive without checking 183 // the context, double-check we're still alive. 184 if childCtx.Err() != nil { 185 return 186 } 187 v := function(childCtx, arg) 188 if childCtx.Err() != nil { 189 return 190 } 191 192 p.mu.Lock() 193 defer p.mu.Unlock() 194 // It's theoretically possible that the promise has been cancelled out 195 // of the run that started us, and then started running again since we 196 // checked childCtx above. Even so, that should be harmless, since each 197 // run should produce the same results. 198 if p.state != stateRunning { 199 return 200 } 201 202 p.value = v 203 p.function = nil // aid GC 204 p.state = stateCompleted 205 close(p.done) 206 }) 207 }() 208 209 return p.wait(ctx) 210 } 211 212 // wait waits for the value to be computed, or ctx to be cancelled. p.mu must be locked. 
213 func (p *Promise) wait(ctx context.Context) (interface{}, error) { 214 p.waiters++ 215 done := p.done 216 p.mu.Unlock() 217 218 select { 219 case <-done: 220 p.mu.Lock() 221 defer p.mu.Unlock() 222 if p.state == stateCompleted { 223 return p.value, nil 224 } 225 return nil, nil 226 case <-ctx.Done(): 227 p.mu.Lock() 228 defer p.mu.Unlock() 229 p.waiters-- 230 if p.waiters == 0 && p.state == stateRunning { 231 p.cancel() 232 close(p.done) 233 p.state = stateIdle 234 p.done = nil 235 p.cancel = nil 236 } 237 return nil, ctx.Err() 238 } 239 } 240 241 // An EvictionPolicy controls the eviction behavior of keys in a Store when 242 // they no longer have any references. 243 type EvictionPolicy int 244 245 const ( 246 // ImmediatelyEvict evicts keys as soon as they no longer have references. 247 ImmediatelyEvict EvictionPolicy = iota 248 249 // NeverEvict does not evict keys. 250 NeverEvict 251 ) 252 253 // A Store maps arbitrary keys to reference-counted promises. 254 // 255 // The zero value is a valid Store, though a store may also be created via 256 // NewStore if a custom EvictionPolicy is required. 257 type Store struct { 258 evictionPolicy EvictionPolicy 259 260 promisesMu sync.Mutex 261 promises map[interface{}]*Promise 262 } 263 264 // NewStore creates a new store with the given eviction policy. 265 func NewStore(policy EvictionPolicy) *Store { 266 return &Store{evictionPolicy: policy} 267 } 268 269 // Promise returns a reference-counted promise for the future result of 270 // calling the specified function. 271 // 272 // Calls to Promise with the same key return the same promise, incrementing its 273 // reference count. The caller must call the returned function to decrement 274 // the promise's reference count when it is no longer needed. The returned 275 // function must not be called more than once. 276 // 277 // Once the last reference has been released, the promise is removed from the 278 // store. 
279 func (store *Store) Promise(key interface{}, function Function) (*Promise, func()) { 280 store.promisesMu.Lock() 281 p, ok := store.promises[key] 282 if !ok { 283 p = NewPromise(reflect.TypeOf(key).String(), function) 284 if store.promises == nil { 285 store.promises = map[interface{}]*Promise{} 286 } 287 store.promises[key] = p 288 } 289 p.refcount++ 290 store.promisesMu.Unlock() 291 292 var released int32 293 release := func() { 294 if !atomic.CompareAndSwapInt32(&released, 0, 1) { 295 panic("release called more than once") 296 } 297 store.promisesMu.Lock() 298 299 p.refcount-- 300 if p.refcount == 0 && store.evictionPolicy != NeverEvict { 301 // Inv: if p.refcount > 0, then store.promises[key] == p. 302 delete(store.promises, key) 303 } 304 store.promisesMu.Unlock() 305 } 306 307 return p, release 308 } 309 310 // Stats returns the number of each type of key in the store. 311 func (s *Store) Stats() map[reflect.Type]int { 312 result := map[reflect.Type]int{} 313 314 s.promisesMu.Lock() 315 defer s.promisesMu.Unlock() 316 317 for k := range s.promises { 318 result[reflect.TypeOf(k)]++ 319 } 320 return result 321 } 322 323 // DebugOnlyIterate iterates through the store and, for each completed 324 // promise, calls f(k, v) for the map key k and function result v. It 325 // should only be used for debugging purposes. 326 func (s *Store) DebugOnlyIterate(f func(k, v interface{})) { 327 s.promisesMu.Lock() 328 defer s.promisesMu.Unlock() 329 330 for k, p := range s.promises { 331 if v := p.Cached(); v != nil { 332 f(k, v) 333 } 334 } 335 }