github.com/Tyktechnologies/tyk@v2.9.5+incompatible/gateway/mw_redis_cache.go (about) 1 package gateway 2 3 import ( 4 "bufio" 5 "bytes" 6 "crypto/md5" 7 "encoding/base64" 8 "encoding/hex" 9 "errors" 10 "hash" 11 "io" 12 "io/ioutil" 13 "net/http" 14 "strconv" 15 "strings" 16 "time" 17 18 "golang.org/x/sync/singleflight" 19 20 "github.com/TykTechnologies/murmur3" 21 "github.com/TykTechnologies/tyk/headers" 22 "github.com/TykTechnologies/tyk/regexp" 23 "github.com/TykTechnologies/tyk/request" 24 "github.com/TykTechnologies/tyk/storage" 25 ) 26 27 const ( 28 upstreamCacheHeader = "x-tyk-cache-action-set" 29 upstreamCacheTTLHeader = "x-tyk-cache-action-set-ttl" 30 ) 31 32 // RedisCacheMiddleware is a caching middleware that will pull data from Redis instead of the upstream proxy 33 type RedisCacheMiddleware struct { 34 BaseMiddleware 35 CacheStore storage.Handler 36 sh SuccessHandler 37 singleFlight singleflight.Group 38 } 39 40 func (m *RedisCacheMiddleware) Name() string { 41 return "RedisCacheMiddleware" 42 } 43 44 func (m *RedisCacheMiddleware) Init() { 45 m.sh = SuccessHandler{m.BaseMiddleware} 46 } 47 48 func (m *RedisCacheMiddleware) EnabledForSpec() bool { 49 return m.Spec.CacheOptions.EnableCache 50 } 51 52 func (m *RedisCacheMiddleware) CreateCheckSum(req *http.Request, keyName string, regex string, additionalKeyFromHeaders string) (string, error) { 53 h := md5.New() 54 io.WriteString(h, req.Method) 55 io.WriteString(h, "-"+req.URL.String()) 56 if additionalKeyFromHeaders != "" { 57 io.WriteString(h, "-"+additionalKeyFromHeaders) 58 } 59 60 if e := addBodyHash(req, regex, h); e != nil { 61 return "", e 62 } 63 64 reqChecksum := hex.EncodeToString(h.Sum(nil)) 65 return m.Spec.APIID + keyName + reqChecksum, nil 66 } 67 68 func addBodyHash(req *http.Request, regex string, h hash.Hash) (err error) { 69 if !isBodyHashRequired(req) { 70 return nil 71 } 72 73 bodyBytes, err := readBody(req) 74 if err != nil { 75 return err 76 } 77 78 mur := murmur3.New128() 79 if regex == "" { 80 mur.Write(bodyBytes) 81 io.WriteString(h, "-"+hex.EncodeToString(mur.Sum(nil))) 82 return nil 83 } 84 r, err := regexp.Compile(regex) 85 if err != nil { 86 return err 87 } 88 89 if match := r.Find(bodyBytes); match != nil { 90 mur.Write(match) 91 io.WriteString(h, "-"+hex.EncodeToString(mur.Sum(nil))) 92 } 93 return nil 94 } 95 96 func readBody(req *http.Request) (bodyBytes []byte, err error) { 97 if n, ok := req.Body.(nopCloser); ok { 98 n.Seek(0, io.SeekStart) 99 bodyBytes, err = ioutil.ReadAll(n) 100 if err != nil { 101 return nil, err 102 } 103 n.Seek(0, io.SeekStart) // reset for any next read. 104 return 105 } 106 107 req.Body = copyBody(req.Body) 108 bodyBytes, err = ioutil.ReadAll(req.Body) 109 if err != nil { 110 return nil, err 111 } 112 req.Body.(nopCloser).Seek(0, io.SeekStart) // reset for any next read. 113 return 114 } 115 116 func isBodyHashRequired(request *http.Request) bool { 117 return request.Body != nil && 118 (request.Method == http.MethodPost || 119 request.Method == http.MethodPut || 120 request.Method == http.MethodPatch) 121 122 } 123 124 func (m *RedisCacheMiddleware) getTimeTTL(cacheTTL int64) string { 125 timeNow := time.Now().Unix() 126 newTTL := timeNow + cacheTTL 127 asStr := strconv.Itoa(int(newTTL)) 128 return asStr 129 } 130 131 func (m *RedisCacheMiddleware) isTimeStampExpired(timestamp string) bool { 132 now := time.Now() 133 134 i, err := strconv.ParseInt(timestamp, 10, 64) 135 if err != nil { 136 log.Error(err) 137 } 138 tm := time.Unix(i, 0) 139 140 log.Debug("Time Now: ", now) 141 log.Debug("Expires: ", tm) 142 if tm.Before(now) { 143 log.Debug("Expriy caught in TS!") 144 return true 145 } 146 147 return false 148 } 149 150 func (m *RedisCacheMiddleware) encodePayload(payload, timestamp string) string { 151 sEnc := base64.StdEncoding.EncodeToString([]byte(payload)) 152 return sEnc + "|" + timestamp 153 } 154 155 func (m *RedisCacheMiddleware) decodePayload(payload string) (string, string, error) { 156 data := strings.Split(payload, "|") 157 switch len(data) { 158 case 1: 159 return data[0], "", nil 160 case 2: 161 sDec, err := base64.StdEncoding.DecodeString(data[0]) 162 if err != nil { 163 return "", "", err 164 } 165 166 return string(sDec), data[1], nil 167 } 168 return "", "", errors.New("Decoding failed, array length wrong") 169 } 170 171 // ProcessRequest will run any checks on the request on the way through the system, return an error to have the chain fail 172 func (m *RedisCacheMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _ interface{}) (error, int) { 173 var stat RequestStatus 174 var cacheKeyRegex string 175 var cacheMeta *EndPointCacheMeta 176 177 _, versionPaths, _, _ := m.Spec.Version(r) 178 isVirtual, _ := m.Spec.CheckSpecMatchesStatus(r, versionPaths, VirtualPath) 179 180 // Lets see if we can throw a sledgehammer at this 181 if m.Spec.CacheOptions.CacheAllSafeRequests && isSafeMethod(r.Method) { 182 stat = StatusCached 183 } 184 if stat != StatusCached { 185 // New request checker, more targeted, less likely to fail 186 found, meta := m.Spec.CheckSpecMatchesStatus(r, versionPaths, Cached) 187 if found { 188 cacheMeta = meta.(*EndPointCacheMeta) 189 stat = StatusCached 190 cacheKeyRegex = cacheMeta.CacheKeyRegex 191 } 192 } 193 194 // Cached route matched, let go 195 if stat != StatusCached { 196 return nil, http.StatusOK 197 } 198 token := ctxGetAuthToken(r) 199 200 // No authentication data? use the IP. 201 if token == "" { 202 token = request.RealIP(r) 203 } 204 205 var errCreatingChecksum bool 206 var retBlob string 207 key, err := m.CreateCheckSum(r, token, cacheKeyRegex, m.getCacheKeyFromHeaders(r)) 208 if err != nil { 209 log.Debug("Error creating checksum. Skipping cache check") 210 errCreatingChecksum = true 211 } else { 212 v, sfErr, _ := m.singleFlight.Do(key, func() (interface{}, error) { 213 return m.CacheStore.GetKey(key) 214 }) 215 retBlob = v.(string) 216 err = sfErr 217 } 218 219 if err != nil { 220 if !errCreatingChecksum { 221 log.Debug("Cache enabled, but record not found") 222 } 223 // Pass through to proxy AND CACHE RESULT 224 225 var resVal *http.Response 226 if isVirtual { 227 log.Debug("This is a virtual function") 228 vp := VirtualEndpoint{BaseMiddleware: m.BaseMiddleware} 229 vp.Init() 230 resVal = vp.ServeHTTPForCache(w, r, nil) 231 } else { 232 // This passes through and will write the value to the writer, but spit out a copy for the cache 233 log.Debug("Not virtual, passing") 234 if newURL := ctxGetURLRewriteTarget(r); newURL != nil { 235 r.URL = newURL 236 ctxSetURLRewriteTarget(r, nil) 237 } 238 if newMethod := ctxGetTransformRequestMethod(r); newMethod != "" { 239 r.Method = newMethod 240 ctxSetTransformRequestMethod(r, "") 241 } 242 sr := m.sh.ServeHTTPWithCache(w, r) 243 resVal = sr.Response 244 } 245 246 cacheThisRequest := true 247 cacheTTL := m.Spec.CacheOptions.CacheTimeout 248 249 if resVal == nil { 250 log.Warning("Upstream request must have failed, response is empty") 251 return nil, mwStatusRespond 252 } 253 254 cacheOnlyResponseCodes := m.Spec.CacheOptions.CacheOnlyResponseCodes 255 // override api main CacheOnlyResponseCodes by endpoint specific if provided 256 if cacheMeta != nil && len(cacheMeta.CacheOnlyResponseCodes) > 0 { 257 cacheOnlyResponseCodes = cacheMeta.CacheOnlyResponseCodes 258 } 259 260 // make sure the status codes match if specified 261 if len(cacheOnlyResponseCodes) > 0 { 262 foundCode := false 263 for _, code := range cacheOnlyResponseCodes { 264 if code == resVal.StatusCode { 265 foundCode = true 266 break 267 } 268 } 269 cacheThisRequest = foundCode 270 } 271 272 // Are we using upstream cache control? 273 if m.Spec.CacheOptions.EnableUpstreamCacheControl { 274 log.Debug("Upstream control enabled") 275 // Do we cache? 276 if resVal.Header.Get(upstreamCacheHeader) == "" { 277 log.Warning("Upstream cache action not found, not caching") 278 cacheThisRequest = false 279 } 280 281 cacheTTLHeader := upstreamCacheTTLHeader 282 if m.Spec.CacheOptions.CacheControlTTLHeader != "" { 283 cacheTTLHeader = m.Spec.CacheOptions.CacheControlTTLHeader 284 } 285 286 ttl := resVal.Header.Get(cacheTTLHeader) 287 if ttl != "" { 288 log.Debug("TTL Set upstream") 289 cacheAsInt, err := strconv.Atoi(ttl) 290 if err != nil { 291 log.Error("Failed to decode TTL cache value: ", err) 292 cacheTTL = m.Spec.CacheOptions.CacheTimeout 293 } else { 294 cacheTTL = int64(cacheAsInt) 295 } 296 } 297 } 298 299 if cacheThisRequest && !errCreatingChecksum { 300 log.Debug("Caching request to redis") 301 var wireFormatReq bytes.Buffer 302 resVal.Write(&wireFormatReq) 303 log.Debug("Cache TTL is:", cacheTTL) 304 ts := m.getTimeTTL(cacheTTL) 305 toStore := m.encodePayload(wireFormatReq.String(), ts) 306 go m.CacheStore.SetKey(key, toStore, cacheTTL) 307 } 308 309 return nil, mwStatusRespond 310 } 311 312 cachedData, timestamp, err := m.decodePayload(retBlob) 313 if err != nil { 314 // Tere was an issue with this cache entry - lets remove it: 315 m.CacheStore.DeleteKey(key) 316 return nil, http.StatusOK 317 } 318 319 if m.isTimeStampExpired(timestamp) || len(cachedData) == 0 { 320 m.CacheStore.DeleteKey(key) 321 return nil, http.StatusOK 322 } 323 324 log.Debug("Cache got: ", cachedData) 325 bufData := bufio.NewReader(strings.NewReader(cachedData)) 326 newRes, err := http.ReadResponse(bufData, r) 327 if err != nil { 328 log.Error("Could not create response object: ", err) 329 } 330 nopCloseResponseBody(newRes) 331 332 defer newRes.Body.Close() 333 for _, h := range hopHeaders { 334 newRes.Header.Del(h) 335 } 336 337 copyHeader(w.Header(), newRes.Header) 338 session := ctxGetSession(r) 339 340 // Only add ratelimit data to keyed sessions 341 if session != nil { 342 quotaMax, quotaRemaining, _, quotaRenews := session.GetQuotaLimitByAPIID(m.Spec.APIID) 343 w.Header().Set(headers.XRateLimitLimit, strconv.Itoa(int(quotaMax))) 344 w.Header().Set(headers.XRateLimitRemaining, strconv.Itoa(int(quotaRemaining))) 345 w.Header().Set(headers.XRateLimitReset, strconv.Itoa(int(quotaRenews))) 346 } 347 w.Header().Set("x-tyk-cached-response", "1") 348 349 if reqEtag := r.Header.Get("If-None-Match"); reqEtag != "" { 350 if respEtag := newRes.Header.Get("Etag"); respEtag != "" { 351 if strings.Contains(reqEtag, respEtag) { 352 newRes.StatusCode = http.StatusNotModified 353 } 354 } 355 } 356 357 w.WriteHeader(newRes.StatusCode) 358 if newRes.StatusCode != http.StatusNotModified { 359 m.Proxy.CopyResponse(w, newRes.Body) 360 } 361 362 // Record analytics 363 if !m.Spec.DoNotTrack { 364 m.sh.RecordHit(r, Latency{}, newRes.StatusCode, newRes) 365 } 366 367 // Stop any further execution 368 return nil, mwStatusRespond 369 } 370 371 func isSafeMethod(method string) bool { 372 return method == http.MethodGet || method == http.MethodHead || method == http.MethodOptions 373 } 374 375 func (m *RedisCacheMiddleware) getCacheKeyFromHeaders(r *http.Request) (key string) { 376 key = "" 377 for _, header := range m.Spec.CacheOptions.CacheByHeaders { 378 key += header + "-" + r.Header.Get(header) 379 } 380 return 381 }