github.com/uber/kraken@v0.1.4/utils/dedup/request_cache.go (about)

     1  // Copyright (c) 2016-2019 Uber Technologies, Inc.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  package dedup
    15  
    16  import (
    17  	"errors"
    18  	"sync"
    19  	"time"
    20  
    21  	"github.com/andres-erbsen/clock"
    22  )
    23  
    24  // RequestCacheConfig defines RequestCache configuration.
    25  type RequestCacheConfig struct {
    26  	NotFoundTTL     time.Duration `yaml:"not_found_ttl"`
    27  	ErrorTTL        time.Duration `yaml:"error_ttl"`
    28  	CleanupInterval time.Duration `yaml:"cleanup_interval"`
    29  	NumWorkers      int           `yaml:"num_workers"`
    30  	BusyTimeout     time.Duration `yaml:"busy_timeout"`
    31  }
    32  
    33  func (c *RequestCacheConfig) applyDefaults() {
    34  	// TODO(codyg): If the cached error TTL is lower than the interval in which
    35  	// clients are polling a 202 endpoint, then it is possible that the client
    36  	// will never hit the actual error because it expires in between requests.
    37  	if c.NotFoundTTL == 0 {
    38  		c.NotFoundTTL = 15 * time.Second
    39  	}
    40  	if c.ErrorTTL == 0 {
    41  		c.ErrorTTL = 15 * time.Second
    42  	}
    43  	if c.CleanupInterval == 0 {
    44  		c.CleanupInterval = 5 * time.Second
    45  	}
    46  	if c.NumWorkers == 0 {
    47  		c.NumWorkers = 10000
    48  	}
    49  	if c.BusyTimeout == 0 {
    50  		c.BusyTimeout = 5 * time.Second
    51  	}
    52  }
    53  
    54  // RequestCache errors.
    55  var (
    56  	ErrRequestPending = errors.New("request pending")
    57  	ErrWorkersBusy    = errors.New("no workers available to handle request")
    58  )
    59  
    60  type cachedError struct {
    61  	err       error
    62  	expiresAt time.Time
    63  }
    64  
    65  func (e *cachedError) expired(now time.Time) bool {
    66  	return now.After(e.expiresAt)
    67  }
    68  
    69  // Request defines functions which encapsulate a request.
    70  type Request func() error
    71  
    72  // ErrorMatcher defines functions which RequestCache uses to detect user defined
    73  // errors.
    74  type ErrorMatcher func(error) bool
    75  
    76  // RequestCache tracks pending requests and caches errors for configurable TTLs.
    77  // It is used to prevent request duplication and DDOS-ing external components.
    78  // Each request is represented by an arbitrary id string determined by the user.
    79  type RequestCache struct {
    80  	config RequestCacheConfig
    81  	clk    clock.Clock
    82  
    83  	mu         sync.Mutex // Protects access to the following fields:
    84  	pending    map[string]bool
    85  	errors     map[string]*cachedError
    86  	lastClean  time.Time
    87  	isNotFound ErrorMatcher
    88  
    89  	numWorkers chan struct{}
    90  }
    91  
    92  // NewRequestCache creates a new RequestCache.
    93  func NewRequestCache(config RequestCacheConfig, clk clock.Clock) *RequestCache {
    94  	config.applyDefaults()
    95  	return &RequestCache{
    96  		config:     config,
    97  		clk:        clk,
    98  		pending:    make(map[string]bool),
    99  		errors:     make(map[string]*cachedError),
   100  		lastClean:  clk.Now(),
   101  		isNotFound: func(error) bool { return false },
   102  		numWorkers: make(chan struct{}, config.NumWorkers),
   103  	}
   104  }
   105  
   106  // SetNotFound sets the ErrorMatcher for activating the configured NotFoundTTL
   107  // for errors returned by Request functions.
   108  func (c *RequestCache) SetNotFound(m ErrorMatcher) {
   109  	c.mu.Lock()
   110  	defer c.mu.Unlock()
   111  
   112  	c.isNotFound = m
   113  }
   114  
   115  // Start concurrently runs r under the given id. Any error returned by r will be
   116  // cached for the configured TTL. If there is already a function executing under
   117  // id, Start returns ErrRequestPending. If there are no available workers to run
   118  // r, Start returns ErrWorkersBusy.
   119  func (c *RequestCache) Start(id string, r Request) error {
   120  	if err := c.reserve(id); err != nil {
   121  		return err
   122  	}
   123  	if err := c.reserveWorker(); err != nil {
   124  		c.release(id)
   125  		return err
   126  	}
   127  	go func() {
   128  		defer c.releaseWorker()
   129  		c.run(id, r)
   130  	}()
   131  	return nil
   132  }
   133  
   134  func (c *RequestCache) reserve(id string) error {
   135  	c.mu.Lock()
   136  	defer c.mu.Unlock()
   137  
   138  	// Periodically remove expired errors.
   139  	if c.clk.Now().Sub(c.lastClean) > c.config.CleanupInterval {
   140  		for id, cerr := range c.errors {
   141  			if cerr.expired(c.clk.Now()) {
   142  				delete(c.errors, id)
   143  			}
   144  		}
   145  		c.lastClean = c.clk.Now()
   146  	}
   147  
   148  	if c.pending[id] {
   149  		return ErrRequestPending
   150  	}
   151  	if cerr, ok := c.errors[id]; ok && !cerr.expired(c.clk.Now()) {
   152  		return cerr.err
   153  	}
   154  
   155  	c.pending[id] = true
   156  
   157  	return nil
   158  }
   159  
   160  func (c *RequestCache) run(id string, r Request) {
   161  	if err := r(); err != nil {
   162  		c.error(id, err)
   163  		return
   164  	}
   165  	c.release(id)
   166  }
   167  
   168  func (c *RequestCache) release(id string) {
   169  	c.mu.Lock()
   170  	defer c.mu.Unlock()
   171  
   172  	delete(c.pending, id)
   173  }
   174  
   175  func (c *RequestCache) error(id string, err error) {
   176  	c.mu.Lock()
   177  	defer c.mu.Unlock()
   178  
   179  	var ttl time.Duration
   180  	if c.isNotFound(err) {
   181  		ttl = c.config.NotFoundTTL
   182  	} else {
   183  		ttl = c.config.ErrorTTL
   184  	}
   185  	delete(c.pending, id)
   186  	c.errors[id] = &cachedError{err: err, expiresAt: c.clk.Now().Add(ttl)}
   187  }
   188  
   189  func (c *RequestCache) reserveWorker() error {
   190  	select {
   191  	case c.numWorkers <- struct{}{}:
   192  		return nil
   193  	case <-c.clk.After(c.config.BusyTimeout):
   194  		return ErrWorkersBusy
   195  	}
   196  }
   197  
   198  func (c *RequestCache) releaseWorker() {
   199  	<-c.numWorkers
   200  }