github.com/Tyktechnologies/tyk@v2.9.5+incompatible/gateway/handler_success.go (about)

     1  package gateway
     2  
     3  import (
     4  	"bytes"
     5  	"encoding/base64"
     6  	"io"
     7  	"io/ioutil"
     8  	"net/http"
     9  	"runtime/pprof"
    10  	"strconv"
    11  	"strings"
    12  	"time"
    13  
    14  	cache "github.com/pmylund/go-cache"
    15  
    16  	"github.com/TykTechnologies/tyk/config"
    17  	"github.com/TykTechnologies/tyk/ctx"
    18  	"github.com/TykTechnologies/tyk/headers"
    19  	"github.com/TykTechnologies/tyk/request"
    20  	"github.com/TykTechnologies/tyk/user"
    21  )
    22  
    23  const (
    24  	keyDataDeveloperID    = "tyk_developer_id"
    25  	keyDataDeveloperEmail = "tyk_developer_email"
    26  )
    27  
    28  var (
    29  	// key session memory cache
    30  	SessionCache = cache.New(10*time.Second, 5*time.Second)
    31  
    32  	// org session memory cache
    33  	ExpiryCache = cache.New(600*time.Second, 10*time.Minute)
    34  
    35  	// memory cache to store arbitrary items
    36  	UtilCache = cache.New(time.Hour, 10*time.Minute)
    37  )
    38  
// ProxyResponse is the result returned by the proxy's ServeHTTP and
// ServeHTTPForCache methods: the upstream response together with the
// measured upstream latency.
type ProxyResponse struct {
	// Response is the upstream response. It may be nil; callers in this
	// file check for nil before recording analytics.
	Response *http.Response
	// UpstreamLatency the time it takes to do roundtrip to upstream. Total time
	// taken for the gateway to receive response from upstream host.
	UpstreamLatency time.Duration
}
    45  
// ReturningHttpHandler is a handler whose serve methods return a
// ProxyResponse in addition to receiving the ResponseWriter.
//
// NOTE(review): SuccessHandler calls ServeHTTP/ServeHTTPForCache with this
// shape on its Proxy; presumably the proxy implementation satisfies this
// interface — confirm against the proxy type declared elsewhere.
type ReturningHttpHandler interface {
	// ServeHTTP proxies the request and returns the upstream result.
	ServeHTTP(http.ResponseWriter, *http.Request) ProxyResponse
	// ServeHTTPForCache proxies the request and returns a result intended
	// to be stored by the response cache.
	ServeHTTPForCache(http.ResponseWriter, *http.Request) ProxyResponse
	// CopyResponse copies response data from the reader to the writer;
	// exact semantics are defined by the implementation.
	CopyResponse(io.Writer, io.Reader)
}
    51  
// SuccessHandler represents the final ServeHTTP() request for a proxied API request.
// It forwards the request upstream through s.Proxy (promoted from the
// embedded BaseMiddleware) and records an analytics hit for the response.
type SuccessHandler struct {
	BaseMiddleware
}
    56  
    57  func tagHeaders(r *http.Request, th []string, tags []string) []string {
    58  	for k, v := range r.Header {
    59  		cleanK := strings.ToLower(k)
    60  		ok := false
    61  		for _, hname := range th {
    62  			if hname == cleanK {
    63  				ok = true
    64  				break
    65  			}
    66  		}
    67  
    68  		if ok {
    69  			for _, val := range v {
    70  				tagName := cleanK + "-" + val
    71  				tags = append(tags, tagName)
    72  			}
    73  		}
    74  	}
    75  
    76  	return tags
    77  }
    78  
    79  func addVersionHeader(w http.ResponseWriter, r *http.Request, globalConf config.Config) {
    80  	if ctxGetDefaultVersion(r) {
    81  		if vinfo := ctxGetVersionInfo(r); vinfo != nil {
    82  			if globalConf.VersionHeader != "" {
    83  				w.Header().Set(globalConf.VersionHeader, vinfo.Name)
    84  			}
    85  		}
    86  	}
    87  }
    88  
    89  func estimateTagsCapacity(session *user.SessionState, apiSpec *APISpec) int {
    90  	size := 5 // that number of tags expected to be added at least before we record hit
    91  	if session != nil {
    92  		size += len(session.Tags)
    93  
    94  		size += len(session.ApplyPolicies)
    95  
    96  		if session.GetMetaData() != nil {
    97  			if _, ok := session.GetMetaDataByKey(keyDataDeveloperID); ok {
    98  				size += 1
    99  			}
   100  		}
   101  	}
   102  
   103  	if apiSpec.GlobalConfig.DBAppConfOptions.NodeIsSegmented {
   104  		size += len(apiSpec.GlobalConfig.DBAppConfOptions.Tags)
   105  	}
   106  
   107  	size += len(apiSpec.TagHeaders)
   108  
   109  	return size
   110  }
   111  
   112  func getSessionTags(session *user.SessionState) []string {
   113  	tags := make([]string, 0, len(session.Tags)+len(session.ApplyPolicies)+1)
   114  
   115  	// add policy IDs
   116  	for _, polID := range session.ApplyPolicies {
   117  		tags = append(tags, "pol-"+polID)
   118  	}
   119  
   120  	if session.GetMetaData() != nil {
   121  		if developerID, ok := session.GetMetaData()[keyDataDeveloperID].(string); ok {
   122  			tags = append(tags, "dev-"+developerID)
   123  		}
   124  	}
   125  
   126  	tags = append(tags, session.Tags...)
   127  
   128  	return tags
   129  }
   130  
// RecordHit builds an AnalyticsRecord for the proxied request r, hands it to
// analytics.RecordHit, and then reports the request's total latency to the
// health check.
//
// Parameters:
//   - r: the inbound request; its context supplies the auth token, the key
//     session and the tracked-path information.
//   - timing: latency breakdown; timing.Total is stored on the record and
//     reported to the health check.
//   - code: the response status code stored on the record.
//   - responseCopy: the upstream response. It may be nil (see the comment
//     in the detailed-recording branch). When detailed recording is on, its
//     Body is read and then replaced by an in-memory reader over the same
//     bytes.
//
// Nothing happens at all when the API definition sets DoNotTrack. The
// analytics record is only built when GlobalConfig.StoreAnalytics(ip) allows
// it for the client IP. The health report and the optional heap profile run
// regardless of that check.
func (s *SuccessHandler) RecordHit(r *http.Request, timing Latency, code int, responseCopy *http.Response) {

	if s.Spec.DoNotTrack {
		return
	}

	ip := request.RealIP(r)
	if s.Spec.GlobalConfig.StoreAnalytics(ip) {

		t := time.Now()

		// Track the key ID if it exists
		token := ctxGetAuthToken(r)

		// Track version data
		version := s.Spec.getVersionFromRequest(r)
		if version == "" {
			version = "Non Versioned"
		}

		// If OAuth, we need to grab it from the session, which may or may not exist
		oauthClientID := ""
		var alias string
		session := ctxGetSession(r)
		// Pre-size the tag slice; the capacity is only a hint.
		tags := make([]string, 0, estimateTagsCapacity(session, s.Spec))
		if session != nil {
			oauthClientID = session.OauthClientID
			tags = append(tags, getSessionTags(session)...)
			alias = session.Alias
		}

		if len(s.Spec.TagHeaders) > 0 {
			tags = tagHeaders(r, s.Spec.TagHeaders, tags)
		}

		rawRequest := ""
		rawResponse := ""

		// Raw request/response capture is opt-in (see recordDetail); both
		// are stored base64-encoded.
		if recordDetail(r, s.Spec) {
			// Get the wire format representation
			// NOTE(review): http.Request.Write also writes (and reads) r.Body,
			// and its error is ignored here. By this point the proxy has
			// presumably consumed the body already — confirm what rawRequest
			// actually contains.
			var wireFormatReq bytes.Buffer
			r.Write(&wireFormatReq)
			rawRequest = base64.StdEncoding.EncodeToString(wireFormatReq.Bytes())
			// responseCopy, unlike requestCopy, can be nil
			// here - if the response was cached in
			// mw_redis_cache, RecordHit gets passed a nil
			// response copy.
			// TODO: pass a copy of the cached response in
			// mw_redis_cache instead? is there a reason not
			// to include that in the analytics?
			if responseCopy != nil {
				// Buffer the body so it can be restored after serialization.
				// On a read error the partial contents are still used.
				contents, err := ioutil.ReadAll(responseCopy.Body)
				if err != nil {
					log.Error("Couldn't read response body", err)
				}

				// Serialize through the reader returned by respBodyReader
				// (defined elsewhere).
				responseCopy.Body = respBodyReader(r, responseCopy)

				// Get the wire format representation
				var wireFormatRes bytes.Buffer
				responseCopy.Write(&wireFormatRes)
				// Restore Body to the buffered bytes so later readers of
				// responseCopy still see the full body.
				// NOTE(review): the original Body is not closed here.
				responseCopy.Body = ioutil.NopCloser(bytes.NewBuffer(contents))
				rawResponse = base64.StdEncoding.EncodeToString(wireFormatRes.Bytes())
			}
		}

		// Record the tracked (templated) path when one was set on the
		// context and endpoint tracking is not suppressed. Otherwise use
		// the raw URL path.
		trackEP := false
		trackedPath := r.URL.Path
		if p := ctxGetTrackedPath(r); p != "" && !ctxGetDoNotTrack(r) {
			trackEP = true
			trackedPath = p
		}

		// Fall back to the API's target host when the request URL carries none.
		host := r.URL.Host
		if host == "" && s.Spec.target != nil {
			host = s.Spec.target.Host
		}

		// Positional (unkeyed) literal: the order below must match the field
		// order of AnalyticsRecord, which is declared elsewhere. The trailing
		// t is presumably the expiry timestamp that SetExpiry overwrites
		// below — confirm.
		record := AnalyticsRecord{
			r.Method,
			host,
			trackedPath,
			r.URL.Path,
			r.ContentLength,
			r.Header.Get(headers.UserAgent),
			t.Day(),
			t.Month(),
			t.Year(),
			t.Hour(),
			code,
			token,
			t,
			version,
			s.Spec.Name,
			s.Spec.APIID,
			s.Spec.OrgID,
			oauthClientID,
			timing.Total,
			timing,
			rawRequest,
			rawResponse,
			ip,
			GeoData{},
			NetworkStats{},
			tags,
			alias,
			trackEP,
			t,
		}

		if s.Spec.GlobalConfig.AnalyticsConfig.EnableGeoIP {
			record.GetGeo(ip)
		}

		// When org data age is enforced, a positive per-org expiry overrides
		// the API's ExpireAnalyticsAfter.
		expiresAfter := s.Spec.ExpireAnalyticsAfter
		if s.Spec.GlobalConfig.EnforceOrgDataAge {
			orgExpireDataTime := s.OrgSessionExpiry(s.Spec.OrgID)

			if orgExpireDataTime > 0 {
				expiresAfter = orgExpireDataTime
			}
		}

		record.SetExpiry(expiresAfter)

		if s.Spec.GlobalConfig.AnalyticsConfig.NormaliseUrls.Enabled {
			record.NormalisePath(&s.Spec.GlobalConfig)
		}

		analytics.RecordHit(&record)
	}

	// Report in health check
	reportHealthValue(s.Spec, RequestLog, strconv.FormatInt(timing.Total, 10))

	// When a memory profile file is configured, a heap profile is written on
	// every call; the write error is ignored.
	if memProfFile != nil {
		pprof.WriteHeapProfile(memProfFile)
	}
}
   270  
   271  func recordDetail(r *http.Request, spec *APISpec) bool {
   272  	if spec.EnableDetailedRecording {
   273  		return true
   274  	}
   275  
   276  	session := ctxGetSession(r)
   277  	if session != nil {
   278  		if session.EnableDetailedRecording {
   279  			return true
   280  		}
   281  	}
   282  
   283  	// Are we even checking?
   284  	if !spec.GlobalConfig.EnforceOrgDataDetailLogging {
   285  		return spec.GlobalConfig.AnalyticsConfig.EnableDetailedRecording
   286  	}
   287  
   288  	// We are, so get session data
   289  	ses := r.Context().Value(ctx.OrgSessionContext)
   290  	if ses == nil {
   291  		// no session found, use global config
   292  		return spec.GlobalConfig.AnalyticsConfig.EnableDetailedRecording
   293  	}
   294  
   295  	// Session found
   296  	return ses.(user.SessionState).EnableDetailedRecording
   297  }
   298  
   299  // ServeHTTP will store the request details in the analytics store if necessary and proxy the request to it's
   300  // final destination, this is invoked by the ProxyHandler or right at the start of a request chain if the URL
   301  // Spec states the path is Ignored
   302  func (s *SuccessHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) *http.Response {
   303  	log.Debug("Started proxy")
   304  	defer s.Base().UpdateRequestSession(r)
   305  
   306  	versionDef := s.Spec.VersionDefinition
   307  	if !s.Spec.VersionData.NotVersioned && versionDef.Location == "url" && versionDef.StripPath {
   308  		part := s.Spec.getVersionFromRequest(r)
   309  
   310  		log.Info("Stripping version from url: ", part)
   311  
   312  		r.URL.Path = strings.Replace(r.URL.Path, part+"/", "", 1)
   313  		r.URL.RawPath = strings.Replace(r.URL.RawPath, part+"/", "", 1)
   314  	}
   315  
   316  	// Make sure we get the correct target URL
   317  	if s.Spec.Proxy.StripListenPath {
   318  		log.Debug("Stripping: ", s.Spec.Proxy.ListenPath)
   319  		r.URL.Path = s.Spec.StripListenPath(r, r.URL.Path)
   320  		r.URL.RawPath = s.Spec.StripListenPath(r, r.URL.RawPath)
   321  		log.Debug("Upstream Path is: ", r.URL.Path)
   322  	}
   323  
   324  	addVersionHeader(w, r, s.Spec.GlobalConfig)
   325  
   326  	t1 := time.Now()
   327  	resp := s.Proxy.ServeHTTP(w, r)
   328  
   329  	millisec := DurationToMillisecond(time.Since(t1))
   330  	log.Debug("Upstream request took (ms): ", millisec)
   331  
   332  	if resp.Response != nil {
   333  		latency := Latency{
   334  			Total:    int64(millisec),
   335  			Upstream: int64(DurationToMillisecond(resp.UpstreamLatency)),
   336  		}
   337  		s.RecordHit(r, latency, resp.Response.StatusCode, resp.Response)
   338  	}
   339  	log.Debug("Done proxy")
   340  	return nil
   341  }
   342  
   343  // ServeHTTPWithCache will store the request details in the analytics store if necessary and proxy the request to it's
   344  // final destination, this is invoked by the ProxyHandler or right at the start of a request chain if the URL
   345  // Spec states the path is Ignored Itwill also return a response object for the cache
   346  func (s *SuccessHandler) ServeHTTPWithCache(w http.ResponseWriter, r *http.Request) ProxyResponse {
   347  
   348  	versionDef := s.Spec.VersionDefinition
   349  	if !s.Spec.VersionData.NotVersioned && versionDef.Location == "url" && versionDef.StripPath {
   350  		part := s.Spec.getVersionFromRequest(r)
   351  
   352  		log.Info("Stripping version from url: ", part)
   353  
   354  		r.URL.Path = strings.Replace(r.URL.Path, part+"/", "", 1)
   355  		r.URL.RawPath = strings.Replace(r.URL.RawPath, part+"/", "", 1)
   356  	}
   357  
   358  	// Make sure we get the correct target URL
   359  	if s.Spec.Proxy.StripListenPath {
   360  		r.URL.Path = s.Spec.StripListenPath(r, r.URL.Path)
   361  		r.URL.RawPath = s.Spec.StripListenPath(r, r.URL.RawPath)
   362  	}
   363  
   364  	t1 := time.Now()
   365  	inRes := s.Proxy.ServeHTTPForCache(w, r)
   366  	millisec := DurationToMillisecond(time.Since(t1))
   367  
   368  	addVersionHeader(w, r, s.Spec.GlobalConfig)
   369  
   370  	log.Debug("Upstream request took (ms): ", millisec)
   371  
   372  	if inRes.Response != nil {
   373  		latency := Latency{
   374  			Total:    int64(millisec),
   375  			Upstream: int64(DurationToMillisecond(inRes.UpstreamLatency)),
   376  		}
   377  		s.RecordHit(r, latency, inRes.Response.StatusCode, inRes.Response)
   378  	}
   379  
   380  	return inRes
   381  }