github.com/dbernstein1/tyk@v2.9.0-beta9-dl-apic+incompatible/gateway/mw_redis_cache.go (about) 1 package gateway 2 3 import ( 4 "bufio" 5 "bytes" 6 "crypto/md5" 7 "encoding/base64" 8 "encoding/hex" 9 "errors" 10 "io" 11 "io/ioutil" 12 "net/http" 13 "strconv" 14 "strings" 15 "time" 16 17 "golang.org/x/sync/singleflight" 18 19 "github.com/TykTechnologies/murmur3" 20 "github.com/TykTechnologies/tyk/regexp" 21 "github.com/TykTechnologies/tyk/request" 22 "github.com/TykTechnologies/tyk/storage" 23 ) 24 25 const ( 26 upstreamCacheHeader = "x-tyk-cache-action-set" 27 upstreamCacheTTLHeader = "x-tyk-cache-action-set-ttl" 28 ) 29 30 // RedisCacheMiddleware is a caching middleware that will pull data from Redis instead of the upstream proxy 31 type RedisCacheMiddleware struct { 32 BaseMiddleware 33 CacheStore storage.Handler 34 sh SuccessHandler 35 singleFlight singleflight.Group 36 } 37 38 func (m *RedisCacheMiddleware) Name() string { 39 return "RedisCacheMiddleware" 40 } 41 42 func (m *RedisCacheMiddleware) Init() { 43 m.sh = SuccessHandler{m.BaseMiddleware} 44 } 45 46 func (m *RedisCacheMiddleware) EnabledForSpec() bool { 47 return m.Spec.CacheOptions.EnableCache 48 } 49 50 func (m *RedisCacheMiddleware) CreateCheckSum(req *http.Request, keyName string, regex string) (string, error) { 51 h := md5.New() 52 io.WriteString(h, req.Method) 53 io.WriteString(h, "-") 54 io.WriteString(h, req.URL.String()) 55 if req.Method == http.MethodPost { 56 if req.Body != nil { 57 bodyBytes, err := ioutil.ReadAll(req.Body) 58 59 if err != nil { 60 return "", err 61 } 62 63 defer req.Body.Close() 64 req.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes)) 65 66 m := murmur3.New128() 67 if regex == "" { 68 io.WriteString(h, "-") 69 m.Write(bodyBytes) 70 io.WriteString(h, hex.EncodeToString(m.Sum(nil))) 71 } else { 72 r, err := regexp.Compile(regex) 73 if err != nil { 74 return "", err 75 } 76 match := r.Find(bodyBytes) 77 if match != nil { 78 io.WriteString(h, "-") 79 m.Write(match) 80 io.WriteString(h, hex.EncodeToString(m.Sum(nil))) 81 } 82 } 83 } 84 } 85 86 reqChecksum := hex.EncodeToString(h.Sum(nil)) 87 return m.Spec.APIID + keyName + reqChecksum, nil 88 } 89 90 func (m *RedisCacheMiddleware) getTimeTTL(cacheTTL int64) string { 91 timeNow := time.Now().Unix() 92 newTTL := timeNow + cacheTTL 93 asStr := strconv.Itoa(int(newTTL)) 94 return asStr 95 } 96 97 func (m *RedisCacheMiddleware) isTimeStampExpired(timestamp string) bool { 98 now := time.Now() 99 100 i, err := strconv.ParseInt(timestamp, 10, 64) 101 if err != nil { 102 log.Error(err) 103 } 104 tm := time.Unix(i, 0) 105 106 log.Debug("Time Now: ", now) 107 log.Debug("Expires: ", tm) 108 if tm.Before(now) { 109 log.Debug("Expriy caught in TS!") 110 return true 111 } 112 113 return false 114 } 115 116 func (m *RedisCacheMiddleware) encodePayload(payload, timestamp string) string { 117 sEnc := base64.StdEncoding.EncodeToString([]byte(payload)) 118 return sEnc + "|" + timestamp 119 } 120 121 func (m *RedisCacheMiddleware) decodePayload(payload string) (string, string, error) { 122 data := strings.Split(payload, "|") 123 switch len(data) { 124 case 1: 125 return data[0], "", nil 126 case 2: 127 sDec, err := base64.StdEncoding.DecodeString(data[0]) 128 if err != nil { 129 return "", "", err 130 } 131 132 return string(sDec), data[1], nil 133 } 134 return "", "", errors.New("Decoding failed, array length wrong") 135 } 136 137 // ProcessRequest will run any checks on the request on the way through the system, return an error to have the chain fail 138 func (m *RedisCacheMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _ interface{}) (error, int) { 139 // Only allow idempotent (safe) methods 140 if r.Method != "GET" && r.Method != "HEAD" && r.Method != "OPTIONS" && r.Method != "POST" { 141 return nil, http.StatusOK 142 } 143 144 var stat RequestStatus 145 var cacheKeyRegex string 146 147 _, versionPaths, _, _ := m.Spec.Version(r) 148 isVirtual, _ := m.Spec.CheckSpecMatchesStatus(r, versionPaths, VirtualPath) 149 150 // Lets see if we can throw a sledgehammer at this 151 if m.Spec.CacheOptions.CacheAllSafeRequests && r.Method != "POST" { 152 stat = StatusCached 153 } 154 if stat != StatusCached { 155 // New request checker, more targeted, less likely to fail 156 found, meta := m.Spec.CheckSpecMatchesStatus(r, versionPaths, Cached) 157 if found { 158 cacheMeta := meta.(*EndPointCacheMeta) 159 stat = StatusCached 160 cacheKeyRegex = cacheMeta.CacheKeyRegex 161 } 162 } 163 164 // Cached route matched, let go 165 if stat != StatusCached { 166 return nil, http.StatusOK 167 } 168 token := ctxGetAuthToken(r) 169 170 // No authentication data? use the IP. 171 if token == "" { 172 token = request.RealIP(r) 173 } 174 175 var errCreatingChecksum bool 176 var retBlob string 177 key, err := m.CreateCheckSum(r, token, cacheKeyRegex) 178 if err != nil { 179 log.Debug("Error creating checksum. Skipping cache check") 180 errCreatingChecksum = true 181 } else { 182 retBlob, err = m.CacheStore.GetKey(key) 183 v, sfErr, _ := m.singleFlight.Do(key, func() (interface{}, error) { 184 return m.CacheStore.GetKey(key) 185 }) 186 retBlob = v.(string) 187 err = sfErr 188 } 189 190 if err != nil { 191 if !errCreatingChecksum { 192 log.Debug("Cache enabled, but record not found") 193 } 194 // Pass through to proxy AND CACHE RESULT 195 196 var resVal *http.Response 197 if isVirtual { 198 log.Debug("This is a virtual function") 199 vp := VirtualEndpoint{BaseMiddleware: m.BaseMiddleware} 200 vp.Init() 201 resVal = vp.ServeHTTPForCache(w, r, nil) 202 } else { 203 // This passes through and will write the value to the writer, but spit out a copy for the cache 204 log.Debug("Not virtual, passing") 205 resVal = m.sh.ServeHTTPWithCache(w, r) 206 } 207 208 cacheThisRequest := true 209 cacheTTL := m.Spec.CacheOptions.CacheTimeout 210 211 if resVal == nil { 212 log.Warning("Upstream request must have failed, response is empty") 213 return nil, http.StatusOK 214 } 215 216 // make sure the status codes match if specified 217 if len(m.Spec.CacheOptions.CacheOnlyResponseCodes) > 0 { 218 foundCode := false 219 for _, code := range m.Spec.CacheOptions.CacheOnlyResponseCodes { 220 if code == resVal.StatusCode { 221 foundCode = true 222 break 223 } 224 } 225 if !foundCode { 226 cacheThisRequest = false 227 } 228 } 229 230 // Are we using upstream cache control? 231 if m.Spec.CacheOptions.EnableUpstreamCacheControl { 232 log.Debug("Upstream control enabled") 233 // Do we cache? 234 if resVal.Header.Get(upstreamCacheHeader) == "" { 235 log.Warning("Upstream cache action not found, not caching") 236 cacheThisRequest = false 237 } 238 239 cacheTTLHeader := upstreamCacheTTLHeader 240 if m.Spec.CacheOptions.CacheControlTTLHeader != "" { 241 cacheTTLHeader = m.Spec.CacheOptions.CacheControlTTLHeader 242 } 243 244 ttl := resVal.Header.Get(cacheTTLHeader) 245 if ttl != "" { 246 log.Debug("TTL Set upstream") 247 cacheAsInt, err := strconv.Atoi(ttl) 248 if err != nil { 249 log.Error("Failed to decode TTL cache value: ", err) 250 cacheTTL = m.Spec.CacheOptions.CacheTimeout 251 } else { 252 cacheTTL = int64(cacheAsInt) 253 } 254 } 255 } 256 257 if cacheThisRequest && !errCreatingChecksum { 258 log.Debug("Caching request to redis") 259 var wireFormatReq bytes.Buffer 260 resVal.Write(&wireFormatReq) 261 log.Debug("Cache TTL is:", cacheTTL) 262 ts := m.getTimeTTL(cacheTTL) 263 toStore := m.encodePayload(wireFormatReq.String(), ts) 264 go m.CacheStore.SetKey(key, toStore, cacheTTL) 265 } 266 267 return nil, mwStatusRespond 268 } 269 270 cachedData, timestamp, err := m.decodePayload(retBlob) 271 if err != nil { 272 // Tere was an issue with this cache entry - lets remove it: 273 m.CacheStore.DeleteKey(key) 274 return nil, http.StatusOK 275 } 276 277 if m.isTimeStampExpired(timestamp) || len(cachedData) == 0 { 278 m.CacheStore.DeleteKey(key) 279 return nil, http.StatusOK 280 } 281 282 log.Debug("Cache got: ", cachedData) 283 bufData := bufio.NewReader(strings.NewReader(cachedData)) 284 newRes, err := http.ReadResponse(bufData, r) 285 if err != nil { 286 log.Error("Could not create response object: ", err) 287 } 288 nopCloseResponseBody(newRes) 289 290 defer newRes.Body.Close() 291 for _, h := range hopHeaders { 292 newRes.Header.Del(h) 293 } 294 295 copyHeader(w.Header(), newRes.Header) 296 session := ctxGetSession(r) 297 298 // Only add ratelimit data to keyed sessions 299 if session != nil { 300 quotaMax, quotaRemaining, _, quotaRenews := session.GetQuotaLimitByAPIID(m.Spec.APIID) 301 w.Header().Set(XRateLimitLimit, strconv.Itoa(int(quotaMax))) 302 w.Header().Set(XRateLimitRemaining, strconv.Itoa(int(quotaRemaining))) 303 w.Header().Set(XRateLimitReset, strconv.Itoa(int(quotaRenews))) 304 } 305 w.Header().Set("x-tyk-cached-response", "1") 306 307 if reqEtag := r.Header.Get("If-None-Match"); reqEtag != "" { 308 if respEtag := newRes.Header.Get("Etag"); respEtag != "" { 309 if strings.Contains(reqEtag, respEtag) { 310 newRes.StatusCode = http.StatusNotModified 311 } 312 } 313 } 314 315 w.WriteHeader(newRes.StatusCode) 316 if newRes.StatusCode != http.StatusNotModified { 317 m.Proxy.CopyResponse(w, newRes.Body) 318 } 319 320 // Record analytics 321 if !m.Spec.DoNotTrack { 322 m.sh.RecordHit(r, 0, newRes.StatusCode, newRes) 323 } 324 325 // Stop any further execution 326 return nil, mwStatusRespond 327 }