github.com/gramework/gramework@v1.8.1-0.20231027140105-82555c9057f5/app_cache.go (about) 1 // +build cache 2 3 package gramework 4 5 import ( 6 "encoding/json" 7 "errors" 8 "time" 9 10 "github.com/VictoriaMetrics/fastcache" 11 ) 12 13 func (opts *CacheOptions) validate() error { 14 if opts.TTL <= 0 { 15 return errors.New("TTL must be grater than 0") 16 } 17 if opts.CacheKey == nil { 18 opts.CacheKey = defaultCacheOpts.CacheKey 19 } 20 if opts.Cacheable == nil { 21 opts.Cacheable = defaultCacheOpts.Cacheable 22 } 23 24 return nil 25 } 26 27 var defaultCacheOpts = NewCacheOptions() 28 29 // NewCacheOptions returns a cache options with default settings. 30 func NewCacheOptions() *CacheOptions { 31 return &CacheOptions{ 32 TTL: 30 * time.Second, 33 Cacheable: func(ctx *Context) bool { 34 if len(ctx.Request.Header.Peek("Authentication")) > 0 { 35 return false 36 } 37 38 if len(ctx.Cookies.Storage) > 0 { 39 return false 40 } 41 42 return true 43 }, 44 CacheKey: func(ctx *Context) []byte { 45 return ctx.Path() 46 }, 47 } 48 } 49 50 // CacheFor is a shortcut to set ttl easily. See app.Cache() for docs. 51 func (app *App) CacheFor(handler interface{}, ttl time.Duration) func(ctx *Context) { 52 opts := app.getCacheOpts() 53 54 opts.TTL = ttl 55 return app.Cache(handler, opts) 56 } 57 58 // Cache wrapper will cache given handler using provided options. If options parameter omitted, 59 // this function will use default options. 60 // 61 // NOTE: Please, your CacheOptions' TTL must be more than 0. 62 func (app *App) Cache(handler interface{}, options ...*CacheOptions) func(ctx *Context) { 63 opts := app.getCacheOpts(options...) 64 65 if err := opts.validate(); err != nil { 66 app.Logger.WithError(err).Fatal("could not initialize cache middleware: check options") 67 } 68 69 wrappedHandler := app.defaultRouter.determineHandler(handler) 70 71 if opts.ReadCache == nil || opts.StoreCache == nil { 72 cache := fastcache.New(1) 73 opts.ReadCache = readFastCache(cache) 74 opts.StoreCache = storeFastCache(cache) 75 } 76 77 return func(ctx *Context) { 78 if opts.Cacheable(ctx) { 79 cacheKey := opts.CacheKey(ctx) 80 if value, isValid := opts.ReadCache(ctx, cacheKey); isValid { 81 serializedHeaders, isValid := opts.ReadCache(ctx, append(cacheKey, []byte("-headers")...)) 82 if isValid { 83 headers := map[string]string{} 84 err := json.Unmarshal(serializedHeaders, &headers) 85 if err == nil { 86 for name, value := range headers { 87 ctx.Response.Header.Set(name, value) 88 } 89 ctx.Response.SetBody(value) 90 return 91 } 92 } 93 } 94 95 wrappedHandler(ctx) 96 97 b := ctx.Response.Body() 98 99 opts.StoreCache(ctx, cacheKey, b, opts.TTL) 100 headers, ok := serializeHeaders(ctx, opts) 101 if ok { 102 opts.StoreCache(ctx, append(cacheKey, []byte("-headers")...), headers, opts.TTL) 103 } 104 return 105 } 106 107 wrappedHandler(ctx) 108 } 109 } 110 111 func serializeHeaders(ctx *Context, opts *CacheOptions) ([]byte, bool) { 112 headers := map[string]string{ 113 "Content-Type": string(ctx.Response.Header.Peek("Content-Type")), 114 "Content-Length": string(ctx.Response.Header.Peek("Content-Length")), 115 } 116 for _, header := range opts.CacheableHeaders { 117 headers[header] = string(ctx.Response.Header.Peek(header)) 118 } 119 for _, header := range opts.NonCacheableHeaders { 120 delete(headers, header) 121 } 122 serialized, err := json.Marshal(headers) 123 return serialized, err == nil 124 } 125 126 func readFastCache(cache *fastcache.Cache) func(_ *Context, key []byte) (value []byte, isValid bool) { 127 return func(_ *Context, key []byte) ([]byte, bool) { 128 return cache.GetWithTimeout(nil, key) 129 } 130 } 131 132 func storeFastCache(cache *fastcache.Cache) func(_ *Context, key, value []byte, ttl time.Duration) { 133 return func(_ *Context, key, value []byte, ttl time.Duration) { 134 cache.SetWithTimeout(key, value, ttl) 135 } 136 } 137 138 func (app *App) getCacheOpts(options ...*CacheOptions) *CacheOptions { 139 opts := defaultCacheOpts 140 switch { 141 case len(options) > 1: 142 app.Logger.Warn("got more than one set of cache options: using the first one.") 143 fallthrough 144 case len(options) == 1: 145 if options[0] != nil { 146 opts = options[0] 147 } 148 case app.DefaultCacheOptions != nil: 149 opts = app.DefaultCacheOptions 150 } 151 152 return opts 153 }