github.com/dbernstein1/tyk@v2.9.0-beta9-dl-apic+incompatible/gateway/mw_redis_cache.go (about)

     1  package gateway
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"crypto/md5"
     7  	"encoding/base64"
     8  	"encoding/hex"
     9  	"errors"
    10  	"io"
    11  	"io/ioutil"
    12  	"net/http"
    13  	"strconv"
    14  	"strings"
    15  	"time"
    16  
    17  	"golang.org/x/sync/singleflight"
    18  
    19  	"github.com/TykTechnologies/murmur3"
    20  	"github.com/TykTechnologies/tyk/regexp"
    21  	"github.com/TykTechnologies/tyk/request"
    22  	"github.com/TykTechnologies/tyk/storage"
    23  )
    24  
    25  const (
    26  	upstreamCacheHeader    = "x-tyk-cache-action-set"
    27  	upstreamCacheTTLHeader = "x-tyk-cache-action-set-ttl"
    28  )
    29  
    30  // RedisCacheMiddleware is a caching middleware that will pull data from Redis instead of the upstream proxy
    31  type RedisCacheMiddleware struct {
    32  	BaseMiddleware
    33  	CacheStore   storage.Handler
    34  	sh           SuccessHandler
    35  	singleFlight singleflight.Group
    36  }
    37  
    38  func (m *RedisCacheMiddleware) Name() string {
    39  	return "RedisCacheMiddleware"
    40  }
    41  
    42  func (m *RedisCacheMiddleware) Init() {
    43  	m.sh = SuccessHandler{m.BaseMiddleware}
    44  }
    45  
    46  func (m *RedisCacheMiddleware) EnabledForSpec() bool {
    47  	return m.Spec.CacheOptions.EnableCache
    48  }
    49  
    50  func (m *RedisCacheMiddleware) CreateCheckSum(req *http.Request, keyName string, regex string) (string, error) {
    51  	h := md5.New()
    52  	io.WriteString(h, req.Method)
    53  	io.WriteString(h, "-")
    54  	io.WriteString(h, req.URL.String())
    55  	if req.Method == http.MethodPost {
    56  		if req.Body != nil {
    57  			bodyBytes, err := ioutil.ReadAll(req.Body)
    58  
    59  			if err != nil {
    60  				return "", err
    61  			}
    62  
    63  			defer req.Body.Close()
    64  			req.Body = ioutil.NopCloser(bytes.NewBuffer(bodyBytes))
    65  
    66  			m := murmur3.New128()
    67  			if regex == "" {
    68  				io.WriteString(h, "-")
    69  				m.Write(bodyBytes)
    70  				io.WriteString(h, hex.EncodeToString(m.Sum(nil)))
    71  			} else {
    72  				r, err := regexp.Compile(regex)
    73  				if err != nil {
    74  					return "", err
    75  				}
    76  				match := r.Find(bodyBytes)
    77  				if match != nil {
    78  					io.WriteString(h, "-")
    79  					m.Write(match)
    80  					io.WriteString(h, hex.EncodeToString(m.Sum(nil)))
    81  				}
    82  			}
    83  		}
    84  	}
    85  
    86  	reqChecksum := hex.EncodeToString(h.Sum(nil))
    87  	return m.Spec.APIID + keyName + reqChecksum, nil
    88  }
    89  
    90  func (m *RedisCacheMiddleware) getTimeTTL(cacheTTL int64) string {
    91  	timeNow := time.Now().Unix()
    92  	newTTL := timeNow + cacheTTL
    93  	asStr := strconv.Itoa(int(newTTL))
    94  	return asStr
    95  }
    96  
    97  func (m *RedisCacheMiddleware) isTimeStampExpired(timestamp string) bool {
    98  	now := time.Now()
    99  
   100  	i, err := strconv.ParseInt(timestamp, 10, 64)
   101  	if err != nil {
   102  		log.Error(err)
   103  	}
   104  	tm := time.Unix(i, 0)
   105  
   106  	log.Debug("Time Now: ", now)
   107  	log.Debug("Expires: ", tm)
   108  	if tm.Before(now) {
   109  		log.Debug("Expriy caught in TS!")
   110  		return true
   111  	}
   112  
   113  	return false
   114  }
   115  
   116  func (m *RedisCacheMiddleware) encodePayload(payload, timestamp string) string {
   117  	sEnc := base64.StdEncoding.EncodeToString([]byte(payload))
   118  	return sEnc + "|" + timestamp
   119  }
   120  
   121  func (m *RedisCacheMiddleware) decodePayload(payload string) (string, string, error) {
   122  	data := strings.Split(payload, "|")
   123  	switch len(data) {
   124  	case 1:
   125  		return data[0], "", nil
   126  	case 2:
   127  		sDec, err := base64.StdEncoding.DecodeString(data[0])
   128  		if err != nil {
   129  			return "", "", err
   130  		}
   131  
   132  		return string(sDec), data[1], nil
   133  	}
   134  	return "", "", errors.New("Decoding failed, array length wrong")
   135  }
   136  
   137  // ProcessRequest will run any checks on the request on the way through the system, return an error to have the chain fail
   138  func (m *RedisCacheMiddleware) ProcessRequest(w http.ResponseWriter, r *http.Request, _ interface{}) (error, int) {
   139  	// Only allow idempotent (safe) methods
   140  	if r.Method != "GET" && r.Method != "HEAD" && r.Method != "OPTIONS" && r.Method != "POST" {
   141  		return nil, http.StatusOK
   142  	}
   143  
   144  	var stat RequestStatus
   145  	var cacheKeyRegex string
   146  
   147  	_, versionPaths, _, _ := m.Spec.Version(r)
   148  	isVirtual, _ := m.Spec.CheckSpecMatchesStatus(r, versionPaths, VirtualPath)
   149  
   150  	// Lets see if we can throw a sledgehammer at this
   151  	if m.Spec.CacheOptions.CacheAllSafeRequests && r.Method != "POST" {
   152  		stat = StatusCached
   153  	}
   154  	if stat != StatusCached {
   155  		// New request checker, more targeted, less likely to fail
   156  		found, meta := m.Spec.CheckSpecMatchesStatus(r, versionPaths, Cached)
   157  		if found {
   158  			cacheMeta := meta.(*EndPointCacheMeta)
   159  			stat = StatusCached
   160  			cacheKeyRegex = cacheMeta.CacheKeyRegex
   161  		}
   162  	}
   163  
   164  	// Cached route matched, let go
   165  	if stat != StatusCached {
   166  		return nil, http.StatusOK
   167  	}
   168  	token := ctxGetAuthToken(r)
   169  
   170  	// No authentication data? use the IP.
   171  	if token == "" {
   172  		token = request.RealIP(r)
   173  	}
   174  
   175  	var errCreatingChecksum bool
   176  	var retBlob string
   177  	key, err := m.CreateCheckSum(r, token, cacheKeyRegex)
   178  	if err != nil {
   179  		log.Debug("Error creating checksum. Skipping cache check")
   180  		errCreatingChecksum = true
   181  	} else {
   182  		retBlob, err = m.CacheStore.GetKey(key)
   183  		v, sfErr, _ := m.singleFlight.Do(key, func() (interface{}, error) {
   184  			return m.CacheStore.GetKey(key)
   185  		})
   186  		retBlob = v.(string)
   187  		err = sfErr
   188  	}
   189  
   190  	if err != nil {
   191  		if !errCreatingChecksum {
   192  			log.Debug("Cache enabled, but record not found")
   193  		}
   194  		// Pass through to proxy AND CACHE RESULT
   195  
   196  		var resVal *http.Response
   197  		if isVirtual {
   198  			log.Debug("This is a virtual function")
   199  			vp := VirtualEndpoint{BaseMiddleware: m.BaseMiddleware}
   200  			vp.Init()
   201  			resVal = vp.ServeHTTPForCache(w, r, nil)
   202  		} else {
   203  			// This passes through and will write the value to the writer, but spit out a copy for the cache
   204  			log.Debug("Not virtual, passing")
   205  			resVal = m.sh.ServeHTTPWithCache(w, r)
   206  		}
   207  
   208  		cacheThisRequest := true
   209  		cacheTTL := m.Spec.CacheOptions.CacheTimeout
   210  
   211  		if resVal == nil {
   212  			log.Warning("Upstream request must have failed, response is empty")
   213  			return nil, http.StatusOK
   214  		}
   215  
   216  		// make sure the status codes match if specified
   217  		if len(m.Spec.CacheOptions.CacheOnlyResponseCodes) > 0 {
   218  			foundCode := false
   219  			for _, code := range m.Spec.CacheOptions.CacheOnlyResponseCodes {
   220  				if code == resVal.StatusCode {
   221  					foundCode = true
   222  					break
   223  				}
   224  			}
   225  			if !foundCode {
   226  				cacheThisRequest = false
   227  			}
   228  		}
   229  
   230  		// Are we using upstream cache control?
   231  		if m.Spec.CacheOptions.EnableUpstreamCacheControl {
   232  			log.Debug("Upstream control enabled")
   233  			// Do we cache?
   234  			if resVal.Header.Get(upstreamCacheHeader) == "" {
   235  				log.Warning("Upstream cache action not found, not caching")
   236  				cacheThisRequest = false
   237  			}
   238  
   239  			cacheTTLHeader := upstreamCacheTTLHeader
   240  			if m.Spec.CacheOptions.CacheControlTTLHeader != "" {
   241  				cacheTTLHeader = m.Spec.CacheOptions.CacheControlTTLHeader
   242  			}
   243  
   244  			ttl := resVal.Header.Get(cacheTTLHeader)
   245  			if ttl != "" {
   246  				log.Debug("TTL Set upstream")
   247  				cacheAsInt, err := strconv.Atoi(ttl)
   248  				if err != nil {
   249  					log.Error("Failed to decode TTL cache value: ", err)
   250  					cacheTTL = m.Spec.CacheOptions.CacheTimeout
   251  				} else {
   252  					cacheTTL = int64(cacheAsInt)
   253  				}
   254  			}
   255  		}
   256  
   257  		if cacheThisRequest && !errCreatingChecksum {
   258  			log.Debug("Caching request to redis")
   259  			var wireFormatReq bytes.Buffer
   260  			resVal.Write(&wireFormatReq)
   261  			log.Debug("Cache TTL is:", cacheTTL)
   262  			ts := m.getTimeTTL(cacheTTL)
   263  			toStore := m.encodePayload(wireFormatReq.String(), ts)
   264  			go m.CacheStore.SetKey(key, toStore, cacheTTL)
   265  		}
   266  
   267  		return nil, mwStatusRespond
   268  	}
   269  
   270  	cachedData, timestamp, err := m.decodePayload(retBlob)
   271  	if err != nil {
   272  		// Tere was an issue with this cache entry - lets remove it:
   273  		m.CacheStore.DeleteKey(key)
   274  		return nil, http.StatusOK
   275  	}
   276  
   277  	if m.isTimeStampExpired(timestamp) || len(cachedData) == 0 {
   278  		m.CacheStore.DeleteKey(key)
   279  		return nil, http.StatusOK
   280  	}
   281  
   282  	log.Debug("Cache got: ", cachedData)
   283  	bufData := bufio.NewReader(strings.NewReader(cachedData))
   284  	newRes, err := http.ReadResponse(bufData, r)
   285  	if err != nil {
   286  		log.Error("Could not create response object: ", err)
   287  	}
   288  	nopCloseResponseBody(newRes)
   289  
   290  	defer newRes.Body.Close()
   291  	for _, h := range hopHeaders {
   292  		newRes.Header.Del(h)
   293  	}
   294  
   295  	copyHeader(w.Header(), newRes.Header)
   296  	session := ctxGetSession(r)
   297  
   298  	// Only add ratelimit data to keyed sessions
   299  	if session != nil {
   300  		quotaMax, quotaRemaining, _, quotaRenews := session.GetQuotaLimitByAPIID(m.Spec.APIID)
   301  		w.Header().Set(XRateLimitLimit, strconv.Itoa(int(quotaMax)))
   302  		w.Header().Set(XRateLimitRemaining, strconv.Itoa(int(quotaRemaining)))
   303  		w.Header().Set(XRateLimitReset, strconv.Itoa(int(quotaRenews)))
   304  	}
   305  	w.Header().Set("x-tyk-cached-response", "1")
   306  
   307  	if reqEtag := r.Header.Get("If-None-Match"); reqEtag != "" {
   308  		if respEtag := newRes.Header.Get("Etag"); respEtag != "" {
   309  			if strings.Contains(reqEtag, respEtag) {
   310  				newRes.StatusCode = http.StatusNotModified
   311  			}
   312  		}
   313  	}
   314  
   315  	w.WriteHeader(newRes.StatusCode)
   316  	if newRes.StatusCode != http.StatusNotModified {
   317  		m.Proxy.CopyResponse(w, newRes.Body)
   318  	}
   319  
   320  	// Record analytics
   321  	if !m.Spec.DoNotTrack {
   322  		m.sh.RecordHit(r, 0, newRes.StatusCode, newRes)
   323  	}
   324  
   325  	// Stop any further execution
   326  	return nil, mwStatusRespond
   327  }