github.com/Tyktechnologies/tyk@v2.9.5+incompatible/gateway/mw_redis_cache.go (about)

     1  package gateway
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/md5"
     7  	"encoding/base64"
     8  	"encoding/hex"
     9  	"errors"
    10  	"hash"
    11  	"io"
    12  	"io/ioutil"
    13  	"net/http"
    14  	"strconv"
    15  	"strings"
    16  	"time"
    17  
    18  	"golang.org/x/sync/singleflight"
    19  
    20  	"github.com/TykTechnologies/murmur3"
    21  	"github.com/TykTechnologies/tyk/headers"
    22  	"github.com/TykTechnologies/tyk/regexp"
    23  	"github.com/TykTechnologies/tyk/request"
    24  	"github.com/TykTechnologies/tyk/storage"
    25  )
    26  
    27  const (
    28  	upstreamCacheHeader    = "x-tyk-cache-action-set"
    29  	upstreamCacheTTLHeader = "x-tyk-cache-action-set-ttl"
    30  )
    31  
    32  // RedisCacheMiddleware is a caching middleware that will pull data from Redis instead of the upstream proxy
    33  type RedisCacheMiddleware struct {
    34  	BaseMiddleware
    35  	CacheStore   storage.Handler
    36  	sh           SuccessHandler
    37  	singleFlight singleflight.Group
    38  }
    39  
    40  func (m *RedisCacheMiddleware) Name() string {
    41  	return "RedisCacheMiddleware"
    42  }
    43  
    44  func (m *RedisCacheMiddleware) Init() {
    45  	m.sh = SuccessHandler{m.BaseMiddleware}
    46  }
    47  
    48  func (m *RedisCacheMiddleware) EnabledForSpec() bool {
    49  	return m.Spec.CacheOptions.EnableCache
    50  }
    51  
    52  func (m *RedisCacheMiddleware) CreateCheckSum(req *http.Request, keyName string, regex string, additionalKeyFromHeaders string) (string, error) {
    53  	h := md5.New()
    54  	io.WriteString(h, req.Method)
    55  	io.WriteString(h, "-"+req.URL.String())
    56  	if additionalKeyFromHeaders != "" {
    57  		io.WriteString(h, "-"+additionalKeyFromHeaders)
    58  	}
    59  
    60  	if e := addBodyHash(req, regex, h); e != nil {
    61  		return "", e
    62  	}
    63  
    64  	reqChecksum := hex.EncodeToString(h.Sum(nil))
    65  	return m.Spec.APIID + keyName + reqChecksum, nil
    66  }
    67  
    68  func addBodyHash(req *http.Request, regex string, h hash.Hash) (err error) {
    69  	if !isBodyHashRequired(req) {
    70  		return nil
    71  	}
    72  
    73  	bodyBytes, err := readBody(req)
    74  	if err != nil {
    75  		return err
    76  	}
    77  
    78  	mur := murmur3.New128()
    79  	if regex == "" {
    80  		mur.Write(bodyBytes)
    81  		io.WriteString(h, "-"+hex.EncodeToString(mur.Sum(nil)))
    82  		return nil
    83  	}
    84  	r, err := regexp.Compile(regex)
    85  	if err != nil {
    86  		return err
    87  	}
    88  
    89  	if match := r.Find(bodyBytes); match != nil {
    90  		mur.Write(match)
    91  		io.WriteString(h, "-"+hex.EncodeToString(mur.Sum(nil)))
    92  	}
    93  	return nil
    94  }
    95  
    96  func readBody(req *http.Request) (bodyBytes []byte, err error) {
    97  	if n, ok := req.Body.(nopCloser); ok {
    98  		n.Seek(0, io.SeekStart)
    99  		bodyBytes, err = ioutil.ReadAll(n)
   100  		if err != nil {
   101  			return nil, err
   102  		}
   103  		n.Seek(0, io.SeekStart) // reset for any next read.
   104  		return
   105  	}
   106  
   107  	req.Body = copyBody(req.Body)
   108  	bodyBytes, err = ioutil.ReadAll(req.Body)
   109  	if err != nil {
   110  		return nil, err
   111  	}
   112  	req.Body.(nopCloser).Seek(0, io.SeekStart) // reset for any next read.
   113  	return
   114  }
   115  
   116  func isBodyHashRequired(request *http.Request) bool {
   117  	return request.Body != nil &&
   118  		(request.Method == http.MethodPost ||
   119  			request.Method == http.MethodPut ||
   120  			request.Method == http.MethodPatch)
   121  
   122  }
   123  
   124  func (m *RedisCacheMiddleware) getTimeTTL(cacheTTL int64) string {
   125  	timeNow := time.Now().Unix()
   126  	newTTL := timeNow + cacheTTL
   127  	asStr := strconv.Itoa(int(newTTL))
   128  	return asStr
   129  }
   130  
   131  func (m *RedisCacheMiddleware) isTimeStampExpired(timestamp string) bool {
   132  	now := time.Now()
   133  
   134  	i, err := strconv.ParseInt(timestamp, 10, 64)
   135  	if err != nil {
   136  		log.Error(err)
   137  	}
   138  	tm := time.Unix(i, 0)
   139  
   140  	log.Debug("Time Now: ", now)
   141  	log.Debug("Expires: ", tm)
   142  	if tm.Before(now) {
   143  		log.Debug("Expriy caught in TS!")
   144  		return true
   145  	}
   146  
   147  	return false
   148  }
   149  
   150  func (m *RedisCacheMiddleware) encodePayload(payload, timestamp string) string {
   151  	sEnc := base64.StdEncoding.EncodeToString([]byte(payload))
   152  	return sEnc + "|" + timestamp
   153  }
   154  
   155  func (m *RedisCacheMiddleware) decodePayload(payload string) (string, string, error) {
   156  	data := strings.Split(payload, "|")
   157  	switch len(data) {
   158  	case 1:
   159  		return data[0], "", nil
   160  	case 2:
   161  		sDec, err := base64.StdEncoding.DecodeString(data[0])
   162  		if err != nil {
   163  			return "", "", err
   164  		}
   165  
   166  		return string(sDec), data[1], nil
   167  	}
   168  	return "", "", errors.New("Decoding failed, array length wrong")
   169  }
   170  
   171  // ProcessRequest will run any checks on the request on the way through the system, return an error to have the chain fail
   172  func (m *RedisCacheMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _ interface{}) (error, int) {
   173  	var stat RequestStatus
   174  	var cacheKeyRegex string
   175  	var cacheMeta *EndPointCacheMeta
   176  
   177  	_, versionPaths, _, _ := m.Spec.Version(r)
   178  	isVirtual, _ := m.Spec.CheckSpecMatchesStatus(r, versionPaths, VirtualPath)
   179  
   180  	// Lets see if we can throw a sledgehammer at this
   181  	if m.Spec.CacheOptions.CacheAllSafeRequests && isSafeMethod(r.Method) {
   182  		stat = StatusCached
   183  	}
   184  	if stat != StatusCached {
   185  		// New request checker, more targeted, less likely to fail
   186  		found, meta := m.Spec.CheckSpecMatchesStatus(r, versionPaths, Cached)
   187  		if found {
   188  			cacheMeta = meta.(*EndPointCacheMeta)
   189  			stat = StatusCached
   190  			cacheKeyRegex = cacheMeta.CacheKeyRegex
   191  		}
   192  	}
   193  
   194  	// Cached route matched, let go
   195  	if stat != StatusCached {
   196  		return nil, http.StatusOK
   197  	}
   198  	token := ctxGetAuthToken(r)
   199  
   200  	// No authentication data? use the IP.
   201  	if token == "" {
   202  		token = request.RealIP(r)
   203  	}
   204  
   205  	var errCreatingChecksum bool
   206  	var retBlob string
   207  	key, err := m.CreateCheckSum(r, token, cacheKeyRegex, m.getCacheKeyFromHeaders(r))
   208  	if err != nil {
   209  		log.Debug("Error creating checksum. Skipping cache check")
   210  		errCreatingChecksum = true
   211  	} else {
   212  		v, sfErr, _ := m.singleFlight.Do(key, func() (interface{}, error) {
   213  			return m.CacheStore.GetKey(key)
   214  		})
   215  		retBlob = v.(string)
   216  		err = sfErr
   217  	}
   218  
   219  	if err != nil {
   220  		if !errCreatingChecksum {
   221  			log.Debug("Cache enabled, but record not found")
   222  		}
   223  		// Pass through to proxy AND CACHE RESULT
   224  
   225  		var resVal *http.Response
   226  		if isVirtual {
   227  			log.Debug("This is a virtual function")
   228  			vp := VirtualEndpoint{BaseMiddleware: m.BaseMiddleware}
   229  			vp.Init()
   230  			resVal = vp.ServeHTTPForCache(w, r, nil)
   231  		} else {
   232  			// This passes through and will write the value to the writer, but spit out a copy for the cache
   233  			log.Debug("Not virtual, passing")
   234  			if newURL := ctxGetURLRewriteTarget(r); newURL != nil {
   235  				r.URL = newURL
   236  				ctxSetURLRewriteTarget(r, nil)
   237  			}
   238  			if newMethod := ctxGetTransformRequestMethod(r); newMethod != "" {
   239  				r.Method = newMethod
   240  				ctxSetTransformRequestMethod(r, "")
   241  			}
   242  			sr := m.sh.ServeHTTPWithCache(w, r)
   243  			resVal = sr.Response
   244  		}
   245  
   246  		cacheThisRequest := true
   247  		cacheTTL := m.Spec.CacheOptions.CacheTimeout
   248  
   249  		if resVal == nil {
   250  			log.Warning("Upstream request must have failed, response is empty")
   251  			return nil, mwStatusRespond
   252  		}
   253  
   254  		cacheOnlyResponseCodes := m.Spec.CacheOptions.CacheOnlyResponseCodes
   255  		// override api main CacheOnlyResponseCodes by endpoint specific if provided
   256  		if cacheMeta != nil && len(cacheMeta.CacheOnlyResponseCodes) > 0 {
   257  			cacheOnlyResponseCodes = cacheMeta.CacheOnlyResponseCodes
   258  		}
   259  
   260  		// make sure the status codes match if specified
   261  		if len(cacheOnlyResponseCodes) > 0 {
   262  			foundCode := false
   263  			for _, code := range cacheOnlyResponseCodes {
   264  				if code == resVal.StatusCode {
   265  					foundCode = true
   266  					break
   267  				}
   268  			}
   269  			cacheThisRequest = foundCode
   270  		}
   271  
   272  		// Are we using upstream cache control?
   273  		if m.Spec.CacheOptions.EnableUpstreamCacheControl {
   274  			log.Debug("Upstream control enabled")
   275  			// Do we cache?
   276  			if resVal.Header.Get(upstreamCacheHeader) == "" {
   277  				log.Warning("Upstream cache action not found, not caching")
   278  				cacheThisRequest = false
   279  			}
   280  
   281  			cacheTTLHeader := upstreamCacheTTLHeader
   282  			if m.Spec.CacheOptions.CacheControlTTLHeader != "" {
   283  				cacheTTLHeader = m.Spec.CacheOptions.CacheControlTTLHeader
   284  			}
   285  
   286  			ttl := resVal.Header.Get(cacheTTLHeader)
   287  			if ttl != "" {
   288  				log.Debug("TTL Set upstream")
   289  				cacheAsInt, err := strconv.Atoi(ttl)
   290  				if err != nil {
   291  					log.Error("Failed to decode TTL cache value: ", err)
   292  					cacheTTL = m.Spec.CacheOptions.CacheTimeout
   293  				} else {
   294  					cacheTTL = int64(cacheAsInt)
   295  				}
   296  			}
   297  		}
   298  
   299  		if cacheThisRequest && !errCreatingChecksum {
   300  			log.Debug("Caching request to redis")
   301  			var wireFormatReq bytes.Buffer
   302  			resVal.Write(&wireFormatReq)
   303  			log.Debug("Cache TTL is:", cacheTTL)
   304  			ts := m.getTimeTTL(cacheTTL)
   305  			toStore := m.encodePayload(wireFormatReq.String(), ts)
   306  			go m.CacheStore.SetKey(key, toStore, cacheTTL)
   307  		}
   308  
   309  		return nil, mwStatusRespond
   310  	}
   311  
   312  	cachedData, timestamp, err := m.decodePayload(retBlob)
   313  	if err != nil {
   314  		// Tere was an issue with this cache entry - lets remove it:
   315  		m.CacheStore.DeleteKey(key)
   316  		return nil, http.StatusOK
   317  	}
   318  
   319  	if m.isTimeStampExpired(timestamp) || len(cachedData) == 0 {
   320  		m.CacheStore.DeleteKey(key)
   321  		return nil, http.StatusOK
   322  	}
   323  
   324  	log.Debug("Cache got: ", cachedData)
   325  	bufData := bufio.NewReader(strings.NewReader(cachedData))
   326  	newRes, err := http.ReadResponse(bufData, r)
   327  	if err != nil {
   328  		log.Error("Could not create response object: ", err)
   329  	}
   330  	nopCloseResponseBody(newRes)
   331  
   332  	defer newRes.Body.Close()
   333  	for _, h := range hopHeaders {
   334  		newRes.Header.Del(h)
   335  	}
   336  
   337  	copyHeader(w.Header(), newRes.Header)
   338  	session := ctxGetSession(r)
   339  
   340  	// Only add ratelimit data to keyed sessions
   341  	if session != nil {
   342  		quotaMax, quotaRemaining, _, quotaRenews := session.GetQuotaLimitByAPIID(m.Spec.APIID)
   343  		w.Header().Set(headers.XRateLimitLimit, strconv.Itoa(int(quotaMax)))
   344  		w.Header().Set(headers.XRateLimitRemaining, strconv.Itoa(int(quotaRemaining)))
   345  		w.Header().Set(headers.XRateLimitReset, strconv.Itoa(int(quotaRenews)))
   346  	}
   347  	w.Header().Set("x-tyk-cached-response", "1")
   348  
   349  	if reqEtag := r.Header.Get("If-None-Match"); reqEtag != "" {
   350  		if respEtag := newRes.Header.Get("Etag"); respEtag != "" {
   351  			if strings.Contains(reqEtag, respEtag) {
   352  				newRes.StatusCode = http.StatusNotModified
   353  			}
   354  		}
   355  	}
   356  
   357  	w.WriteHeader(newRes.StatusCode)
   358  	if newRes.StatusCode != http.StatusNotModified {
   359  		m.Proxy.CopyResponse(w, newRes.Body)
   360  	}
   361  
   362  	// Record analytics
   363  	if !m.Spec.DoNotTrack {
   364  		m.sh.RecordHit(r, Latency{}, newRes.StatusCode, newRes)
   365  	}
   366  
   367  	// Stop any further execution
   368  	return nil, mwStatusRespond
   369  }
   370  
   371  func isSafeMethod(method string) bool {
   372  	return method == http.MethodGet || method == http.MethodHead || method == http.MethodOptions
   373  }
   374  
   375  func (m *RedisCacheMiddleware) getCacheKeyFromHeaders(r *http.Request) (key string) {
   376  	key = ""
   377  	for _, header := range m.Spec.CacheOptions.CacheByHeaders {
   378  		key += header + "-" + r.Header.Get(header)
   379  	}
   380  	return
   381  }