github.com/gramework/gramework@v1.8.1-0.20231027140105-82555c9057f5/app_cache.go (about)

     1  // +build cache
     2  
     3  package gramework
     4  
     5  import (
     6  	"encoding/json"
     7  	"errors"
     8  	"time"
     9  
    10  	"github.com/VictoriaMetrics/fastcache"
    11  )
    12  
    13  func (opts *CacheOptions) validate() error {
    14  	if opts.TTL <= 0 {
    15  		return errors.New("TTL must be grater than 0")
    16  	}
    17  	if opts.CacheKey == nil {
    18  		opts.CacheKey = defaultCacheOpts.CacheKey
    19  	}
    20  	if opts.Cacheable == nil {
    21  		opts.Cacheable = defaultCacheOpts.Cacheable
    22  	}
    23  
    24  	return nil
    25  }
    26  
    27  var defaultCacheOpts = NewCacheOptions()
    28  
    29  // NewCacheOptions returns a cache options with default settings.
    30  func NewCacheOptions() *CacheOptions {
    31  	return &CacheOptions{
    32  		TTL: 30 * time.Second,
    33  		Cacheable: func(ctx *Context) bool {
    34  			if len(ctx.Request.Header.Peek("Authentication")) > 0 {
    35  				return false
    36  			}
    37  
    38  			if len(ctx.Cookies.Storage) > 0 {
    39  				return false
    40  			}
    41  
    42  			return true
    43  		},
    44  		CacheKey: func(ctx *Context) []byte {
    45  			return ctx.Path()
    46  		},
    47  	}
    48  }
    49  
    50  // CacheFor is a shortcut to set ttl easily. See app.Cache() for docs.
    51  func (app *App) CacheFor(handler interface{}, ttl time.Duration) func(ctx *Context) {
    52  	opts := app.getCacheOpts()
    53  
    54  	opts.TTL = ttl
    55  	return app.Cache(handler, opts)
    56  }
    57  
    58  // Cache wrapper will cache given handler using provided options. If options parameter omitted,
    59  // this function will use default options.
    60  //
    61  // NOTE: Please, your CacheOptions' TTL must be more than 0.
    62  func (app *App) Cache(handler interface{}, options ...*CacheOptions) func(ctx *Context) {
    63  	opts := app.getCacheOpts(options...)
    64  
    65  	if err := opts.validate(); err != nil {
    66  		app.Logger.WithError(err).Fatal("could not initialize cache middleware: check options")
    67  	}
    68  
    69  	wrappedHandler := app.defaultRouter.determineHandler(handler)
    70  
    71  	if opts.ReadCache == nil || opts.StoreCache == nil {
    72  		cache := fastcache.New(1)
    73  		opts.ReadCache = readFastCache(cache)
    74  		opts.StoreCache = storeFastCache(cache)
    75  	}
    76  
    77  	return func(ctx *Context) {
    78  		if opts.Cacheable(ctx) {
    79  			cacheKey := opts.CacheKey(ctx)
    80  			if value, isValid := opts.ReadCache(ctx, cacheKey); isValid {
    81  				serializedHeaders, isValid := opts.ReadCache(ctx, append(cacheKey, []byte("-headers")...))
    82  				if isValid {
    83  					headers := map[string]string{}
    84  					err := json.Unmarshal(serializedHeaders, &headers)
    85  					if err == nil {
    86  						for name, value := range headers {
    87  							ctx.Response.Header.Set(name, value)
    88  						}
    89  						ctx.Response.SetBody(value)
    90  						return
    91  					}
    92  				}
    93  			}
    94  
    95  			wrappedHandler(ctx)
    96  
    97  			b := ctx.Response.Body()
    98  
    99  			opts.StoreCache(ctx, cacheKey, b, opts.TTL)
   100  			headers, ok := serializeHeaders(ctx, opts)
   101  			if ok {
   102  				opts.StoreCache(ctx, append(cacheKey, []byte("-headers")...), headers, opts.TTL)
   103  			}
   104  			return
   105  		}
   106  
   107  		wrappedHandler(ctx)
   108  	}
   109  }
   110  
   111  func serializeHeaders(ctx *Context, opts *CacheOptions) ([]byte, bool) {
   112  	headers := map[string]string{
   113  		"Content-Type":   string(ctx.Response.Header.Peek("Content-Type")),
   114  		"Content-Length": string(ctx.Response.Header.Peek("Content-Length")),
   115  	}
   116  	for _, header := range opts.CacheableHeaders {
   117  		headers[header] = string(ctx.Response.Header.Peek(header))
   118  	}
   119  	for _, header := range opts.NonCacheableHeaders {
   120  		delete(headers, header)
   121  	}
   122  	serialized, err := json.Marshal(headers)
   123  	return serialized, err == nil
   124  }
   125  
   126  func readFastCache(cache *fastcache.Cache) func(_ *Context, key []byte) (value []byte, isValid bool) {
   127  	return func(_ *Context, key []byte) ([]byte, bool) {
   128  		return cache.GetWithTimeout(nil, key)
   129  	}
   130  }
   131  
   132  func storeFastCache(cache *fastcache.Cache) func(_ *Context, key, value []byte, ttl time.Duration) {
   133  	return func(_ *Context, key, value []byte, ttl time.Duration) {
   134  		cache.SetWithTimeout(key, value, ttl)
   135  	}
   136  }
   137  
   138  func (app *App) getCacheOpts(options ...*CacheOptions) *CacheOptions {
   139  	opts := defaultCacheOpts
   140  	switch {
   141  	case len(options) > 1:
   142  		app.Logger.Warn("got more than one set of cache options: using the first one.")
   143  		fallthrough
   144  	case len(options) == 1:
   145  		if options[0] != nil {
   146  			opts = options[0]
   147  		}
   148  	case app.DefaultCacheOptions != nil:
   149  		opts = app.DefaultCacheOptions
   150  	}
   151  
   152  	return opts
   153  }