github.com/Tyktechnologies/tyk@v2.9.5+incompatible/gateway/reverse_proxy.go (about)

     1  // Copyright 2011 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  // Fork of Go's net/http/httputil/reverseproxy.go with multiple changes,
     6  // including:
     7  //
     8  // * caching
     9  // * load balancing
    10  // * service discovery
    11  
    12  package gateway
    13  
    14  import (
    15  	"bytes"
    16  	"context"
    17  	"crypto/tls"
    18  	"crypto/x509"
    19  	"errors"
    20  	"fmt"
    21  	"io"
    22  	"io/ioutil"
    23  	"net"
    24  	"net/http"
    25  	"net/url"
    26  	"strconv"
    27  	"strings"
    28  	"sync"
    29  	"time"
    30  
    31  	"github.com/TykTechnologies/tyk/apidef"
    32  	"github.com/TykTechnologies/tyk/config"
    33  	"github.com/TykTechnologies/tyk/ctx"
    34  	"github.com/TykTechnologies/tyk/headers"
    35  	"github.com/TykTechnologies/tyk/regexp"
    36  	"github.com/TykTechnologies/tyk/trace"
    37  	"github.com/TykTechnologies/tyk/user"
    38  	opentracing "github.com/opentracing/opentracing-go"
    39  	"github.com/opentracing/opentracing-go/ext"
    40  	cache "github.com/pmylund/go-cache"
    41  	"github.com/sirupsen/logrus"
    42  	"golang.org/x/net/http2"
    43  )
    44  
// defaultUserAgent is the User-Agent header value the proxy director sets on
// outbound requests that arrive without a User-Agent header of their own.
// It is the literal "Tyk/" followed by the gateway's VERSION constant.
const defaultUserAgent = "Tyk/" + VERSION
    46  
    47  var corsHeaders = []string{
    48  	"Access-Control-Allow-Origin",
    49  	"Access-Control-Expose-Headers",
    50  	"Access-Control-Max-Age",
    51  	"Access-Control-Allow-Credentials",
    52  	"Access-Control-Allow-Methods",
    53  	"Access-Control-Allow-Headers"}
    54  
// ServiceCache holds service-discovery host lists keyed by API ID. It is nil
// until TykNewSingleHostReverseProxy is called for a spec that has service
// discovery enabled, which creates it lazily.
var ServiceCache *cache.Cache

// sdMu guards reads and writes of APISpec.HasRun in urlFromService.
// NOTE(review): the other spec fields touched there (ServiceRefreshInProgress,
// LastGoodHostList) are accessed without this lock.
var sdMu sync.RWMutex
    57  
    58  func urlFromService(spec *APISpec) (*apidef.HostList, error) {
    59  
    60  	doCacheRefresh := func() (*apidef.HostList, error) {
    61  		log.Debug("--> Refreshing")
    62  		spec.ServiceRefreshInProgress = true
    63  		defer func() { spec.ServiceRefreshInProgress = false }()
    64  		sd := ServiceDiscovery{}
    65  		sd.Init(&spec.Proxy.ServiceDiscovery)
    66  		data, err := sd.Target(spec.Proxy.ServiceDiscovery.QueryEndpoint)
    67  		if err != nil {
    68  			return nil, err
    69  		}
    70  		sdMu.Lock()
    71  		spec.HasRun = true
    72  		sdMu.Unlock()
    73  		// Set the cached value
    74  		if data.Len() == 0 {
    75  			log.Warning("[PROXY][SD] Service Discovery returned empty host list! Returning last good set.")
    76  
    77  			if spec.LastGoodHostList == nil {
    78  				log.Warning("[PROXY][SD] Last good host list is nil, returning empty set.")
    79  				spec.LastGoodHostList = apidef.NewHostList()
    80  			}
    81  
    82  			return spec.LastGoodHostList, nil
    83  		}
    84  
    85  		ServiceCache.Set(spec.APIID, data, cache.DefaultExpiration)
    86  		// Stash it too
    87  		spec.LastGoodHostList = data
    88  		return data, nil
    89  	}
    90  	sdMu.RLock()
    91  	hasRun := spec.HasRun
    92  	sdMu.RUnlock()
    93  	// First time? Refresh the cache and return that
    94  	if !hasRun {
    95  		log.Debug("First run! Setting cache")
    96  		return doCacheRefresh()
    97  	}
    98  
    99  	// Not first run - check the cache
   100  	cachedServiceData, found := ServiceCache.Get(spec.APIID)
   101  	if !found {
   102  		if spec.ServiceRefreshInProgress {
   103  			// Are we already refreshing the cache? skip and return last good conf
   104  			log.Debug("Cache expired! But service refresh in progress")
   105  			return spec.LastGoodHostList, nil
   106  		}
   107  		// Refresh the spec
   108  		log.Debug("Cache expired! Refreshing...")
   109  		return doCacheRefresh()
   110  	}
   111  
   112  	log.Debug("Returning from cache.")
   113  	return cachedServiceData.(*apidef.HostList), nil
   114  }
   115  
// httpScheme matches a leading "http://" or "https://", case insensitive.
// It is compiled once at package initialisation using the project's regexp
// package (github.com/TykTechnologies/tyk/regexp).
var httpScheme = regexp.MustCompile(`^(?i)https?://`)
   118  
   119  func EnsureTransport(host, protocol string) string {
   120  	if protocol == "" {
   121  		for _, v := range []string{"http://", "https://"} {
   122  			if strings.HasPrefix(host, v) {
   123  				return host
   124  			}
   125  		}
   126  		return "http://" + host
   127  	}
   128  	prefix := protocol + "://"
   129  	if strings.HasPrefix(host, prefix) {
   130  		return host
   131  	}
   132  	return prefix + host
   133  }
   134  
   135  func nextTarget(targetData *apidef.HostList, spec *APISpec) (string, error) {
   136  	if spec.Proxy.EnableLoadBalancing {
   137  		log.Debug("[PROXY] [LOAD BALANCING] Load balancer enabled, getting upstream target")
   138  		// Use a HostList
   139  		startPos := spec.RoundRobin.WithLen(targetData.Len())
   140  		pos := startPos
   141  		for {
   142  			gotHost, err := targetData.GetIndex(pos)
   143  			if err != nil {
   144  				return "", err
   145  			}
   146  
   147  			host := EnsureTransport(gotHost, spec.Protocol)
   148  
   149  			if !spec.Proxy.CheckHostAgainstUptimeTests {
   150  				return host, nil // we don't care if it's up
   151  			}
   152  			// As checked by HostCheckerManager.AmIPolling
   153  			if GlobalHostChecker.store == nil {
   154  				return host, nil
   155  			}
   156  			if !GlobalHostChecker.HostDown(host) {
   157  				return host, nil // we do care and it's up
   158  			}
   159  			// if the host is down, keep trying all the rest
   160  			// in order from where we started.
   161  			if pos = (pos + 1) % targetData.Len(); pos == startPos {
   162  				return "", fmt.Errorf("all hosts are down, uptime tests are failing")
   163  			}
   164  		}
   165  
   166  	}
   167  	// Use standard target - might still be service data
   168  	log.Debug("TARGET DATA:", targetData)
   169  
   170  	gotHost, err := targetData.GetIndex(0)
   171  	if err != nil {
   172  		return "", err
   173  	}
   174  	return EnsureTransport(gotHost, spec.Protocol), nil
   175  }
   176  
var (
	// onceStartAllHostsDown ensures the loopback "all hosts are down" HTTP
	// server is started only once, on the first call to
	// TykNewSingleHostReverseProxy.
	onceStartAllHostsDown sync.Once

	// allHostsDownURL is the "http://host:port" address of that loopback
	// server. The proxy director routes a request there when nextTarget
	// returns an error, so the client receives a 503.
	allHostsDownURL string
)
   182  
// TykNewSingleHostReverseProxy returns a new ReverseProxy that rewrites
// URLs to the scheme, host, and base path provided in target. If the
// target's path is "/base" and the incoming request was for "/dir",
// the target request will be for /base/dir. This version modifies the
// stdlib version by also setting the host to the target, this allows
// us to work with heroku and other such providers
//
// Side effects visible in this function:
//   - On the first call it starts a loopback HTTP server that answers every
//     request with 503 "all hosts are down", and records its address in
//     allHostsDownURL. It panics if the listener cannot be created.
//   - If spec has service discovery enabled and ServiceCache is still nil,
//     ServiceCache is created with an expiry taken from the spec, then the
//     global config, then a 120 second default.
//
// If logger is nil, an entry based on the package logger is used.
func TykNewSingleHostReverseProxy(target *url.URL, spec *APISpec, logger *logrus.Entry) *ReverseProxy {
	onceStartAllHostsDown.Do(func() {
		handler := func(w http.ResponseWriter, r *http.Request) {
			http.Error(w, "all hosts are down", http.StatusServiceUnavailable)
		}
		// Port 0 lets the OS pick a free port on the loopback interface.
		listener, err := net.Listen("tcp", "127.0.0.1:0")
		if err != nil {
			panic(err)
		}
		server := &http.Server{
			Handler:        http.HandlerFunc(handler),
			ReadTimeout:    1 * time.Second,
			WriteTimeout:   1 * time.Second,
			MaxHeaderBytes: 1 << 20,
		}
		allHostsDownURL = "http://" + listener.Addr().String()
		go func() {
			// Serve only returns on failure; that failure is escalated to a panic.
			panic(server.Serve(listener))
		}()
	})
	if spec.Proxy.ServiceDiscovery.UseDiscoveryService {
		log.Debug("[PROXY] Service discovery enabled")
		if ServiceCache == nil {
			log.Debug("[PROXY] Service cache initialising")
			// Expiry in seconds: per-API setting wins, then the global
			// default, then 120.
			expiry := 120
			if spec.Proxy.ServiceDiscovery.CacheTimeout > 0 {
				expiry = int(spec.Proxy.ServiceDiscovery.CacheTimeout)
			} else if spec.GlobalConfig.ServiceDiscovery.DefaultCacheTimeout > 0 {
				expiry = spec.GlobalConfig.ServiceDiscovery.DefaultCacheTimeout
			}
			ServiceCache = cache.New(time.Duration(expiry)*time.Second, 15*time.Second)
		}
	}

	targetQuery := target.RawQuery
	// director rewrites an outbound request in place so that it addresses
	// the chosen upstream.
	//
	// NOTE(review): target and targetQuery are captured from the enclosing
	// function and reassigned inside the closure, so a target selected for
	// one request stays in effect for later calls until replaced, and the
	// assignments are not synchronised between concurrent calls — confirm
	// this is intended.
	director := func(req *http.Request) {
		hostList := spec.Proxy.StructuredTargetList
		switch {
		case spec.Proxy.ServiceDiscovery.UseDiscoveryService:
			var err error
			hostList, err = urlFromService(spec)
			if err != nil {
				// Lookup failed: leave target as it currently is.
				log.Error("[PROXY] [SERVICE DISCOVERY] Failed target lookup: ", err)
				break
			}
			fallthrough // implies load balancing, with replaced host list
		case spec.Proxy.EnableLoadBalancing:
			host, err := nextTarget(hostList, spec)
			if err != nil {
				// No usable upstream: route to the loopback 503 server.
				log.Error("[PROXY] [LOAD BALANCING] ", err)
				host = allHostsDownURL
			}
			lbRemote, err := url.Parse(host)
			if err != nil {
				log.Error("[PROXY] [LOAD BALANCING] Couldn't parse target URL:", err)
			} else {
				// Only replace target if everything is OK
				target = lbRemote
				targetQuery = target.RawQuery
			}
		}

		targetToUse := target

		// A URL rewrite that flagged ctx.RetainHost means req.URL already
		// names the final destination, host included.
		if spec.URLRewriteEnabled && req.Context().Value(ctx.RetainHost) == true {
			log.Debug("Detected host rewrite, overriding target")
			tmpTarget, err := url.Parse(req.URL.String())
			if err != nil {
				log.Error("Failed to parse URL! Err: ", err)
			} else {
				// Specifically override with a URL rewrite
				targetToUse = tmpTarget
			}
		}

		// No override, and no load balancing? Use the existing target

		// if this is false, there was an url rewrite, thus we
		// don't want to do anything to the path - req.URL is
		// already final.
		if targetToUse == target {
			req.URL.Scheme = targetToUse.Scheme
			req.URL.Host = targetToUse.Host
			req.URL.Path = singleJoiningSlash(targetToUse.Path, req.URL.Path, spec.Proxy.DisableStripSlash)
			if req.URL.RawPath != "" {
				req.URL.RawPath = singleJoiningSlash(targetToUse.Path, req.URL.RawPath, spec.Proxy.DisableStripSlash)
			}
		}

		// Unless configured to preserve it, the Host header follows the target.
		if !spec.Proxy.PreserveHostHeader {
			req.Host = targetToUse.Host
		}

		// Merge the target's query string with the request's; "&" is only
		// needed when both are non-empty.
		if targetQuery == "" || req.URL.RawQuery == "" {
			req.URL.RawQuery = targetQuery + req.URL.RawQuery
		} else {
			req.URL.RawQuery = targetQuery + "&" + req.URL.RawQuery
		}
		if _, ok := req.Header[headers.UserAgent]; !ok {
			// Set Tyk's own default user agent. Without
			// this line, we would get the net/http default.
			req.Header.Set(headers.UserAgent, defaultUserAgent)
		}

		if spec.GlobalConfig.HttpServerOptions.SkipTargetPathEscaping {
			// force RequestURI to skip escaping if API's proxy is set for this
			// if we set opaque here it will force URL.RequestURI to skip escaping
			if req.URL.RawPath != "" {
				req.URL.Opaque = req.URL.RawPath
			}
		} else if req.URL.RawPath == req.URL.Path {
			// this should force URL to do escaping
			req.URL.RawPath = ""
		}

		// WebSocket schemes are mapped to their HTTP equivalents for the
		// outbound request.
		switch req.URL.Scheme {
		case "ws":
			req.URL.Scheme = "http"
		case "wss":
			req.URL.Scheme = "https"
		}
	}

	if logger == nil {
		logger = logrus.NewEntry(log)
	}

	logger = logger.WithField("mw", "ReverseProxy")

	proxy := &ReverseProxy{
		Director:      director,
		TykAPISpec:    spec,
		FlushInterval: time.Duration(spec.GlobalConfig.HttpServerOptions.FlushInterval) * time.Millisecond,
		logger:        logger,
		// Pool of 32 KiB byte buffers, handed out as *[]byte.
		sp: sync.Pool{
			New: func() interface{} {
				buffer := make([]byte, 32*1024)
				return &buffer
			},
		},
	}
	// The error handler needs a back-reference to the proxy, so it is wired
	// up after construction.
	proxy.ErrorHandler.BaseMiddleware = BaseMiddleware{Spec: spec, Proxy: proxy}
	return proxy
}
   333  
// ReverseProxy is an HTTP Handler that takes an incoming request and
// sends it to another server, proxying the response back to the
// client.
type ReverseProxy struct {
	// Director must be a function which modifies
	// the request into a new request to be sent
	// using Transport. Its response is then copied
	// back to the original client unmodified.
	Director func(*http.Request)

	// The transport used to perform proxy requests.
	// If nil, http.DefaultTransport is used.
	Transport http.RoundTripper

	// FlushInterval specifies the flush interval
	// to flush to the client while copying the
	// response body.
	// If zero, no periodic flushing is done.
	FlushInterval time.Duration

	// TLSClientConfig specifies the TLS configuration to use for 'wss'.
	// If nil, the default configuration is used.
	TLSClientConfig *tls.Config

	// TykAPISpec is the API specification this proxy serves; it is set by
	// TykNewSingleHostReverseProxy.
	TykAPISpec *APISpec
	// ErrorHandler is used to report proxying failures to the client. Its
	// BaseMiddleware is populated by TykNewSingleHostReverseProxy.
	ErrorHandler ErrorHandler

	// logger is the log entry used for this proxy's messages; the
	// constructor tags it with mw=ReverseProxy.
	logger *logrus.Entry
	// sp pools *[]byte buffers (32 KiB each as created by the constructor).
	// Presumably used as scratch space when copying bodies — its consumers
	// are outside this part of the file.
	sp sync.Pool
}
   364  
   365  func defaultTransport(dialerTimeout float64) *http.Transport {
   366  	timeout := 30.0
   367  	if dialerTimeout > 0 {
   368  		log.Debug("Setting timeout for outbound request to: ", dialerTimeout)
   369  		timeout = dialerTimeout
   370  	}
   371  
   372  	dialer := &net.Dialer{
   373  		Timeout:   time.Duration(float64(timeout) * float64(time.Second)),
   374  		KeepAlive: 30 * time.Second,
   375  		DualStack: true,
   376  	}
   377  	dialContextFunc := dialer.DialContext
   378  	if dnsCacheManager.IsCacheEnabled() {
   379  		dialContextFunc = dnsCacheManager.WrapDialer(dialer)
   380  	}
   381  
   382  	return &http.Transport{
   383  		DialContext:           dialContextFunc,
   384  		MaxIdleConns:          config.Global().MaxIdleConns,
   385  		MaxIdleConnsPerHost:   config.Global().MaxIdleConnsPerHost, // default is 100
   386  		ResponseHeaderTimeout: time.Duration(dialerTimeout) * time.Second,
   387  		TLSHandshakeTimeout:   10 * time.Second,
   388  	}
   389  }
   390  
   391  func singleJoiningSlash(a, b string, disableStripSlash bool) string {
   392  	if disableStripSlash && len(b) == 0 {
   393  		return a
   394  	}
   395  	a = strings.TrimRight(a, "/")
   396  	b = strings.TrimLeft(b, "/")
   397  	if len(b) > 0 {
   398  		return a + "/" + b
   399  	}
   400  	return a
   401  }
   402  
   403  func removeDuplicateCORSHeader(dst, src http.Header) {
   404  	for _, v := range corsHeaders {
   405  		keyName := http.CanonicalHeaderKey(v)
   406  		if val := dst.Get(keyName); val != "" {
   407  			src.Del(keyName)
   408  		}
   409  	}
   410  }
   411  
   412  func copyHeader(dst, src http.Header) {
   413  
   414  	removeDuplicateCORSHeader(dst, src)
   415  
   416  	for k, vv := range src {
   417  		for _, v := range vv {
   418  			dst.Add(k, v)
   419  		}
   420  	}
   421  }
   422  
   423  func cloneHeader(h http.Header) http.Header {
   424  	h2 := make(http.Header, len(h))
   425  	for k, vv := range h {
   426  		vv2 := make([]string, len(vv))
   427  		copy(vv2, vv)
   428  		h2[k] = vv2
   429  	}
   430  	return h2
   431  }
   432  
   433  // Hop-by-hop headers. These are removed when sent to the backend.
   434  // http://www.w3.org/Protocols/rfc2616/rfc2616-sec13.html
   435  var hopHeaders = []string{
   436  	"Connection",
   437  	"Proxy-Connection", // non-standard but still sent by libcurl and rejected by e.g. google
   438  	"Keep-Alive",
   439  	"Proxy-Authenticate",
   440  	"Proxy-Authorization",
   441  	"Te",      // canonicalized version of "TE"
   442  	"Trailer", // not Trailers per URL above; http://www.rfc-editor.org/errata_search.php?eid=4522
   443  	"Transfer-Encoding",
   444  	"Upgrade",
   445  }
   446  
   447  func (p *ReverseProxy) ServeHTTP(rw http.ResponseWriter, req *http.Request) ProxyResponse {
   448  	startTime := time.Now()
   449  	p.logger.WithField("ts", startTime.UnixNano()).Debug("Started")
   450  	resp := p.WrappedServeHTTP(rw, req, recordDetail(req, p.TykAPISpec))
   451  
   452  	finishTime := time.Since(startTime)
   453  	p.logger.WithField("ns", finishTime.Nanoseconds()).Debug("Finished")
   454  
   455  	// make response body to be nopCloser and re-readable before serve it through chain of middlewares
   456  	nopCloseResponseBody(resp.Response)
   457  
   458  	return resp
   459  }
   460  
   461  func (p *ReverseProxy) ServeHTTPForCache(rw http.ResponseWriter, req *http.Request) ProxyResponse {
   462  	startTime := time.Now()
   463  	p.logger.WithField("ts", startTime.UnixNano()).Debug("Started")
   464  
   465  	resp := p.WrappedServeHTTP(rw, req, true)
   466  	nopCloseResponseBody(resp.Response)
   467  	finishTime := time.Since(startTime)
   468  	p.logger.WithField("ns", finishTime.Nanoseconds()).Debug("Finished")
   469  
   470  	return resp
   471  }
   472  
   473  func (p *ReverseProxy) CheckHardTimeoutEnforced(spec *APISpec, req *http.Request) (bool, float64) {
   474  	if !spec.EnforcedTimeoutEnabled {
   475  		return false, spec.GlobalConfig.ProxyDefaultTimeout
   476  	}
   477  
   478  	_, versionPaths, _, _ := spec.Version(req)
   479  	found, meta := spec.CheckSpecMatchesStatus(req, versionPaths, HardTimeout)
   480  	if found {
   481  		intMeta := meta.(*int)
   482  		p.logger.Debug("HARD TIMEOUT ENFORCED: ", *intMeta)
   483  		return true, float64(*intMeta)
   484  	}
   485  
   486  	return false, spec.GlobalConfig.ProxyDefaultTimeout
   487  }
   488  
   489  func (p *ReverseProxy) CheckHeaderInRemoveList(hdr string, spec *APISpec, req *http.Request) bool {
   490  	vInfo, versionPaths, _, _ := spec.Version(req)
   491  	for _, gdKey := range vInfo.GlobalHeadersRemove {
   492  		if strings.ToLower(gdKey) == strings.ToLower(hdr) {
   493  			return true
   494  		}
   495  	}
   496  
   497  	// Check path config
   498  	if found, meta := spec.CheckSpecMatchesStatus(req, versionPaths, HeaderInjected); found {
   499  		hmeta := meta.(*apidef.HeaderInjectionMeta)
   500  		for _, gdKey := range hmeta.DeleteHeaders {
   501  			if strings.ToLower(gdKey) == strings.ToLower(hdr) {
   502  				return true
   503  			}
   504  		}
   505  	}
   506  
   507  	return false
   508  }
   509  
   510  func (p *ReverseProxy) CheckCircuitBreakerEnforced(spec *APISpec, req *http.Request) (bool, *ExtendedCircuitBreakerMeta) {
   511  	if !spec.CircuitBreakerEnabled {
   512  		return false, nil
   513  	}
   514  
   515  	_, versionPaths, _, _ := spec.Version(req)
   516  	found, meta := spec.CheckSpecMatchesStatus(req, versionPaths, CircuitBreaker)
   517  	if found {
   518  		exMeta := meta.(*ExtendedCircuitBreakerMeta)
   519  		p.logger.Debug("CB Enforced for path: ", *exMeta)
   520  		return true, exMeta
   521  	}
   522  
   523  	return false, nil
   524  }
   525  
   526  func proxyFromAPI(api *APISpec) func(*http.Request) (*url.URL, error) {
   527  	return func(req *http.Request) (*url.URL, error) {
   528  		if api != nil && api.Proxy.Transport.ProxyURL != "" {
   529  			return url.Parse(api.Proxy.Transport.ProxyURL)
   530  		}
   531  		return http.ProxyFromEnvironment(req)
   532  	}
   533  }
   534  
   535  func tlsClientConfig(s *APISpec) *tls.Config {
   536  	config := &tls.Config{}
   537  
   538  	if s.GlobalConfig.ProxySSLInsecureSkipVerify {
   539  		config.InsecureSkipVerify = true
   540  	}
   541  
   542  	if s.Proxy.Transport.SSLInsecureSkipVerify {
   543  		config.InsecureSkipVerify = true
   544  	}
   545  
   546  	if s.GlobalConfig.ProxySSLMinVersion > 0 {
   547  		config.MinVersion = s.GlobalConfig.ProxySSLMinVersion
   548  	}
   549  
   550  	if s.Proxy.Transport.SSLMinVersion > 0 {
   551  		config.MinVersion = s.Proxy.Transport.SSLMinVersion
   552  	}
   553  
   554  	if len(s.GlobalConfig.ProxySSLCipherSuites) > 0 {
   555  		config.CipherSuites = getCipherAliases(s.GlobalConfig.ProxySSLCipherSuites)
   556  	}
   557  
   558  	if len(s.Proxy.Transport.SSLCipherSuites) > 0 {
   559  		config.CipherSuites = getCipherAliases(s.Proxy.Transport.SSLCipherSuites)
   560  	}
   561  
   562  	if !s.GlobalConfig.ProxySSLDisableRenegotiation {
   563  		config.Renegotiation = tls.RenegotiateFreelyAsClient
   564  	}
   565  
   566  	return config
   567  }
   568  
   569  func httpTransport(timeOut float64, rw http.ResponseWriter, req *http.Request, p *ReverseProxy) http.RoundTripper {
   570  	transport := defaultTransport(timeOut) // modifies a newly created transport
   571  	transport.TLSClientConfig = &tls.Config{}
   572  	transport.Proxy = proxyFromAPI(p.TykAPISpec)
   573  
   574  	if config.Global().ProxySSLInsecureSkipVerify {
   575  		transport.TLSClientConfig.InsecureSkipVerify = true
   576  	}
   577  
   578  	if p.TykAPISpec.Proxy.Transport.SSLInsecureSkipVerify {
   579  		transport.TLSClientConfig.InsecureSkipVerify = true
   580  	}
   581  
   582  	// When request routed through the proxy `DialTLS` is not used, and only VerifyPeerCertificate is supported
   583  	// The reason behind two separate checks is that `DialTLS` supports specifying public keys per hostname, and `VerifyPeerCertificate` only global ones, e.g. `*`
   584  	if proxyURL, _ := transport.Proxy(req); proxyURL != nil {
   585  		p.logger.Debug("Detected proxy: " + proxyURL.String())
   586  		transport.TLSClientConfig.VerifyPeerCertificate = verifyPeerCertificatePinnedCheck(p.TykAPISpec, transport.TLSClientConfig)
   587  
   588  		if transport.TLSClientConfig.VerifyPeerCertificate != nil {
   589  			p.logger.Debug("Certificate pinning check is enabled")
   590  		}
   591  	} else {
   592  		transport.DialTLS = customDialTLSCheck(p.TykAPISpec, transport.TLSClientConfig)
   593  	}
   594  
   595  	if p.TykAPISpec.GlobalConfig.ProxySSLMinVersion > 0 {
   596  		transport.TLSClientConfig.MinVersion = p.TykAPISpec.GlobalConfig.ProxySSLMinVersion
   597  	}
   598  
   599  	if p.TykAPISpec.Proxy.Transport.SSLMinVersion > 0 {
   600  		transport.TLSClientConfig.MinVersion = p.TykAPISpec.Proxy.Transport.SSLMinVersion
   601  	}
   602  
   603  	if len(p.TykAPISpec.GlobalConfig.ProxySSLCipherSuites) > 0 {
   604  		transport.TLSClientConfig.CipherSuites = getCipherAliases(p.TykAPISpec.GlobalConfig.ProxySSLCipherSuites)
   605  	}
   606  
   607  	if len(p.TykAPISpec.Proxy.Transport.SSLCipherSuites) > 0 {
   608  		transport.TLSClientConfig.CipherSuites = getCipherAliases(p.TykAPISpec.Proxy.Transport.SSLCipherSuites)
   609  	}
   610  
   611  	if !config.Global().ProxySSLDisableRenegotiation {
   612  		transport.TLSClientConfig.Renegotiation = tls.RenegotiateFreelyAsClient
   613  	}
   614  
   615  	transport.DisableKeepAlives = p.TykAPISpec.GlobalConfig.ProxyCloseConnections
   616  
   617  	if config.Global().ProxyEnableHttp2 {
   618  		http2.ConfigureTransport(transport)
   619  	}
   620  
   621  	return transport
   622  }
   623  
   624  func (p *ReverseProxy) setCommonNameVerifyPeerCertificate(tlsConfig *tls.Config, hostName string) {
   625  	tlsConfig.InsecureSkipVerify = true
   626  
   627  	// if verifyPeerCertificate was set previously, make sure it is also executed
   628  	prevFunc := tlsConfig.VerifyPeerCertificate
   629  	tlsConfig.VerifyPeerCertificate = func(rawCerts [][]byte, verifiedChains [][]*x509.Certificate) error {
   630  		if prevFunc != nil {
   631  			err := prevFunc(rawCerts, verifiedChains)
   632  			if err != nil {
   633  				p.logger.Error("Failed to verify server certificate: " + err.Error())
   634  				return err
   635  			}
   636  		}
   637  
   638  		// followed https://github.com/golang/go/issues/21971#issuecomment-332693931
   639  		certs := make([]*x509.Certificate, len(rawCerts))
   640  		for i, asn1Data := range rawCerts {
   641  			cert, err := x509.ParseCertificate(asn1Data)
   642  			if err != nil {
   643  				return errors.New("failed to parse certificate from server: " + err.Error())
   644  			}
   645  			certs[i] = cert
   646  		}
   647  
   648  		if !p.TykAPISpec.Proxy.Transport.SSLInsecureSkipVerify && !config.Global().ProxySSLInsecureSkipVerify {
   649  			opts := x509.VerifyOptions{
   650  				Roots:         tlsConfig.RootCAs,
   651  				CurrentTime:   time.Now(),
   652  				DNSName:       "", // <- skip hostname verification
   653  				Intermediates: x509.NewCertPool(),
   654  			}
   655  
   656  			for i, cert := range certs {
   657  				if i == 0 {
   658  					continue
   659  				}
   660  				opts.Intermediates.AddCert(cert)
   661  			}
   662  			_, err := certs[0].Verify(opts)
   663  			if err != nil {
   664  				p.logger.Error("Failed to verify server certificate: " + err.Error())
   665  				return err
   666  			}
   667  		}
   668  
   669  		return validateCommonName(hostName, certs[0])
   670  	}
   671  }
   672  
   673  func (p *ReverseProxy) WrappedServeHTTP(rw http.ResponseWriter, req *http.Request, withCache bool) ProxyResponse {
   674  	if trace.IsEnabled() {
   675  		span, ctx := trace.Span(req.Context(), req.URL.Path)
   676  		defer span.Finish()
   677  		ext.SpanKindRPCClient.Set(span)
   678  		req = req.WithContext(ctx)
   679  	}
   680  	var roundTripper http.RoundTripper
   681  
   682  	p.TykAPISpec.Lock()
   683  
   684  	// create HTTP transport
   685  	createTransport := p.TykAPISpec.HTTPTransport == nil
   686  
   687  	// Check if timeouts are set for this endpoint
   688  	if !createTransport && config.Global().MaxConnTime != 0 {
   689  		createTransport = time.Since(p.TykAPISpec.HTTPTransportCreated) > time.Duration(config.Global().MaxConnTime)*time.Second
   690  	}
   691  
   692  	if createTransport {
   693  		_, timeout := p.CheckHardTimeoutEnforced(p.TykAPISpec, req)
   694  		p.TykAPISpec.HTTPTransport = httpTransport(timeout, rw, req, p)
   695  		p.TykAPISpec.HTTPTransportCreated = time.Now()
   696  
   697  		p.logger.Debug("Creating new transport")
   698  	}
   699  
   700  	roundTripper = p.TykAPISpec.HTTPTransport
   701  
   702  	p.TykAPISpec.Unlock()
   703  
   704  	reqCtx := req.Context()
   705  	if cn, ok := rw.(http.CloseNotifier); ok {
   706  		var cancel context.CancelFunc
   707  		reqCtx, cancel = context.WithCancel(reqCtx)
   708  		defer cancel()
   709  		notifyChan := cn.CloseNotify()
   710  		go func() {
   711  			select {
   712  			case <-notifyChan:
   713  				cancel()
   714  			case <-reqCtx.Done():
   715  			}
   716  		}()
   717  	}
   718  
   719  	// Do this before we make a shallow copy
   720  	session := ctxGetSession(req)
   721  
   722  	outreq := new(http.Request)
   723  	logreq := new(http.Request)
   724  
   725  	*outreq = *req // includes shallow copies of maps, but okay
   726  	*logreq = *req
   727  	// remove context data from the copies
   728  	setContext(outreq, context.Background())
   729  	setContext(logreq, context.Background())
   730  
   731  	p.logger.Debug("Upstream request URL: ", req.URL)
   732  
   733  	// We need to double set the context for the outbound request to reprocess the target
   734  	if p.TykAPISpec.URLRewriteEnabled && req.Context().Value(ctx.RetainHost) == true {
   735  		p.logger.Debug("Detected host rewrite, notifying director")
   736  		setCtxValue(outreq, ctx.RetainHost, true)
   737  	}
   738  
   739  	if req.ContentLength == 0 {
   740  		outreq.Body = nil // Issue 16036: nil Body for http.Transport retries
   741  	}
   742  	outreq = outreq.WithContext(reqCtx)
   743  
   744  	outreq.Header = cloneHeader(req.Header)
   745  	if trace.IsEnabled() {
   746  		span := opentracing.SpanFromContext(req.Context())
   747  		trace.Inject(p.TykAPISpec.Name, span, outreq.Header)
   748  	}
   749  	p.Director(outreq)
   750  	outreq.Close = false
   751  
   752  	p.logger.Debug("Outbound request URL: ", outreq.URL.String())
   753  
   754  	outReqUpgrade, reqUpType := IsUpgrade(req)
   755  
   756  	// See RFC 2616, section 14.10.
   757  	if c := outreq.Header.Get("Connection"); c != "" {
   758  		for _, f := range strings.Split(c, ",") {
   759  			if f = strings.TrimSpace(f); f != "" {
   760  				outreq.Header.Del(f)
   761  			}
   762  		}
   763  	}
   764  	// Remove other hop-by-hop headers to the backend. Especially
   765  	// important is "Connection" because we want a persistent
   766  	// connection, regardless of what the client sent to us.
   767  	for _, h := range hopHeaders {
   768  		hv := outreq.Header.Get(h)
   769  		if hv == "" {
   770  			continue
   771  		}
   772  		if h == "Te" && hv == "trailers" {
   773  			continue
   774  		}
   775  		outreq.Header.Del(h)
   776  		logreq.Header.Del(h)
   777  	}
   778  
   779  	if outReqUpgrade {
   780  		outreq.Header.Set("Connection", "Upgrade")
   781  		logreq.Header.Set("Connection", "Upgrade")
   782  		outreq.Header.Set("Upgrade", reqUpType)
   783  		logreq.Header.Set("Upgrade", reqUpType)
   784  	}
   785  
   786  	addrs := requestIPHops(req)
   787  	if !p.CheckHeaderInRemoveList(headers.XForwardFor, p.TykAPISpec, req) {
   788  		outreq.Header.Set(headers.XForwardFor, addrs)
   789  	}
   790  
   791  	// Circuit breaker
   792  	breakerEnforced, breakerConf := p.CheckCircuitBreakerEnforced(p.TykAPISpec, req)
   793  
   794  	// set up TLS certificates for upstream if needed
   795  	var tlsCertificates []tls.Certificate
   796  	if cert := getUpstreamCertificate(outreq.Host, p.TykAPISpec); cert != nil {
   797  		p.logger.Debug("Found upstream mutual TLS certificate")
   798  		tlsCertificates = []tls.Certificate{*cert}
   799  	}
   800  
   801  	p.TykAPISpec.Lock()
   802  	roundTripper.(*http.Transport).TLSClientConfig.Certificates = tlsCertificates
   803  	p.TykAPISpec.Unlock()
   804  
   805  	if p.TykAPISpec.Proxy.Transport.SSLForceCommonNameCheck || config.Global().SSLForceCommonNameCheck {
   806  		// if proxy is enabled, add CommonName verification in verifyPeerCertificate
   807  		// DialTLS is not executed if proxy is used
   808  		httpTransport := roundTripper.(*http.Transport)
   809  
   810  		p.logger.Debug("Using forced SSL CN check")
   811  
   812  		if proxyURL, _ := httpTransport.Proxy(req); proxyURL != nil {
   813  			p.logger.Debug("Detected proxy: " + proxyURL.String())
   814  			tlsConfig := httpTransport.TLSClientConfig
   815  			host, _, _ := net.SplitHostPort(outreq.Host)
   816  			p.setCommonNameVerifyPeerCertificate(tlsConfig, host)
   817  		}
   818  
   819  	}
   820  
   821  	// do request round trip
   822  	var res *http.Response
   823  	var err error
   824  	var upstreamLatency time.Duration
   825  	if breakerEnforced {
   826  		if !breakerConf.CB.Ready() {
   827  			p.logger.Debug("ON REQUEST: Circuit Breaker is in OPEN state")
   828  			p.ErrorHandler.HandleError(rw, logreq, "Service temporarily unavailable.", 503, true)
   829  			return ProxyResponse{}
   830  		}
   831  		p.logger.Debug("ON REQUEST: Circuit Breaker is in CLOSED or HALF-OPEN state")
   832  		begin := time.Now()
   833  		res, err = roundTripper.RoundTrip(outreq)
   834  		upstreamLatency = time.Since(begin)
   835  		if err != nil || res.StatusCode/100 == 5 {
   836  			breakerConf.CB.Fail()
   837  		} else {
   838  			breakerConf.CB.Success()
   839  		}
   840  	} else {
   841  		begin := time.Now()
   842  		res, err = roundTripper.RoundTrip(outreq)
   843  		upstreamLatency = time.Since(begin)
   844  	}
   845  
   846  	if err != nil {
   847  
   848  		token := ctxGetAuthToken(req)
   849  
   850  		var alias string
   851  		if session != nil {
   852  			alias = session.Alias
   853  		}
   854  
   855  		p.logger.WithFields(logrus.Fields{
   856  			"prefix":      "proxy",
   857  			"user_ip":     addrs,
   858  			"server_name": outreq.Host,
   859  			"user_id":     obfuscateKey(token),
   860  			"user_name":   alias,
   861  			"org_id":      p.TykAPISpec.OrgID,
   862  			"api_id":      p.TykAPISpec.APIID,
   863  		}).Error("http: proxy error: ", err)
   864  		if strings.Contains(err.Error(), "timeout awaiting response headers") {
   865  			p.ErrorHandler.HandleError(rw, logreq, "Upstream service reached hard timeout.", http.StatusGatewayTimeout, true)
   866  
   867  			if p.TykAPISpec.Proxy.ServiceDiscovery.UseDiscoveryService {
   868  				if ServiceCache != nil {
   869  					p.logger.Debug("[PROXY] [SERVICE DISCOVERY] Upstream host failed, refreshing host list")
   870  					ServiceCache.Delete(p.TykAPISpec.APIID)
   871  				}
   872  			}
   873  			return ProxyResponse{UpstreamLatency: upstreamLatency}
   874  		}
   875  
   876  		if strings.Contains(err.Error(), "context canceled") {
   877  			p.ErrorHandler.HandleError(rw, logreq, "Client closed request", 499, true)
   878  			return ProxyResponse{UpstreamLatency: upstreamLatency}
   879  		}
   880  
   881  		if strings.Contains(err.Error(), "no such host") {
   882  			p.ErrorHandler.HandleError(rw, logreq, "Upstream host lookup failed", http.StatusInternalServerError, true)
   883  			return ProxyResponse{UpstreamLatency: upstreamLatency}
   884  		}
   885  		p.ErrorHandler.HandleError(rw, logreq, "There was a problem proxying the request", http.StatusInternalServerError, true)
   886  		return ProxyResponse{UpstreamLatency: upstreamLatency}
   887  
   888  	}
   889  
   890  	upgrade, _ := IsUpgrade(req)
   891  	// Deal with 101 Switching Protocols responses: (WebSocket, h2c, etc)
   892  	if upgrade {
   893  		if err := p.handleUpgradeResponse(rw, outreq, res); err != nil {
   894  			p.ErrorHandler.HandleError(rw, logreq, err.Error(), http.StatusInternalServerError, true)
   895  			return ProxyResponse{UpstreamLatency: upstreamLatency}
   896  		}
   897  	}
   898  
   899  	ses := new(user.SessionState)
   900  	ses.Mutex = &sync.RWMutex{}
   901  	if session != nil {
   902  		ses = session
   903  	}
   904  
   905  	// Middleware chain handling here - very simple, but should do
   906  	// the trick. Chain can be empty, in which case this is a no-op.
   907  	if err := handleResponseChain(p.TykAPISpec.ResponseChain, rw, res, req, ses); err != nil {
   908  		p.logger.Error("Response chain failed! ", err)
   909  	}
   910  
   911  	inres := new(http.Response)
   912  	if withCache {
   913  		*inres = *res // includes shallow copies of maps, but okay
   914  
   915  		if !upgrade {
   916  			defer res.Body.Close()
   917  
   918  			// Buffer body data
   919  			var bodyBuffer bytes.Buffer
   920  			bodyBuffer2 := new(bytes.Buffer)
   921  
   922  			p.CopyResponse(&bodyBuffer, res.Body)
   923  			*bodyBuffer2 = bodyBuffer
   924  
   925  			// Create new ReadClosers so we can split output
   926  			res.Body = ioutil.NopCloser(&bodyBuffer)
   927  			inres.Body = ioutil.NopCloser(bodyBuffer2)
   928  		}
   929  	}
   930  
   931  	// We should at least copy the status code in
   932  	inres.StatusCode = res.StatusCode
   933  	inres.ContentLength = res.ContentLength
   934  	p.HandleResponse(rw, res, ses)
   935  	return ProxyResponse{UpstreamLatency: upstreamLatency, Response: inres}
   936  }
   937  
   938  func (p *ReverseProxy) HandleResponse(rw http.ResponseWriter, res *http.Response, ses *user.SessionState) error {
   939  
   940  	// Remove hop-by-hop headers listed in the
   941  	// "Connection" header of the response.
   942  	if c := res.Header.Get(headers.Connection); c != "" {
   943  		for _, f := range strings.Split(c, ",") {
   944  			if f = strings.TrimSpace(f); f != "" {
   945  				res.Header.Del(f)
   946  			}
   947  		}
   948  	}
   949  
   950  	for _, h := range hopHeaders {
   951  		res.Header.Del(h)
   952  	}
   953  	defer res.Body.Close()
   954  
   955  	// Close connections
   956  	if config.Global().CloseConnections {
   957  		res.Header.Set(headers.Connection, "close")
   958  	}
   959  
   960  	// Add resource headers
   961  	if ses != nil {
   962  		// We have found a session, lets report back
   963  		quotaMax, quotaRemaining, _, quotaRenews := ses.GetQuotaLimitByAPIID(p.TykAPISpec.APIID)
   964  		res.Header.Set(headers.XRateLimitLimit, strconv.Itoa(int(quotaMax)))
   965  		res.Header.Set(headers.XRateLimitRemaining, strconv.Itoa(int(quotaRemaining)))
   966  		res.Header.Set(headers.XRateLimitReset, strconv.Itoa(int(quotaRenews)))
   967  	}
   968  
   969  	copyHeader(rw.Header(), res.Header)
   970  
   971  	announcedTrailers := len(res.Trailer)
   972  	if announcedTrailers > 0 {
   973  		trailerKeys := make([]string, 0, len(res.Trailer))
   974  		for k := range res.Trailer {
   975  			trailerKeys = append(trailerKeys, k)
   976  		}
   977  		rw.Header().Add("Trailer", strings.Join(trailerKeys, ", "))
   978  	}
   979  
   980  	rw.WriteHeader(res.StatusCode)
   981  
   982  	if len(res.Trailer) > 0 {
   983  		// Force chunking if we saw a response trailer.
   984  		// This prevents net/http from calculating the length for short
   985  		// bodies and adding a Content-Length.
   986  		if fl, ok := rw.(http.Flusher); ok {
   987  			fl.Flush()
   988  		}
   989  	}
   990  
   991  	p.CopyResponse(rw, res.Body)
   992  
   993  	if len(res.Trailer) == announcedTrailers {
   994  		copyHeader(rw.Header(), res.Trailer)
   995  		return nil
   996  	}
   997  
   998  	for k, vv := range res.Trailer {
   999  		k = http.TrailerPrefix + k
  1000  		for _, v := range vv {
  1001  			rw.Header().Add(k, v)
  1002  		}
  1003  	}
  1004  	return nil
  1005  }
  1006  
  1007  func (p *ReverseProxy) CopyResponse(dst io.Writer, src io.Reader) {
  1008  	if p.FlushInterval != 0 {
  1009  		if wf, ok := dst.(writeFlusher); ok {
  1010  			mlw := &maxLatencyWriter{
  1011  				dst:     wf,
  1012  				latency: p.FlushInterval,
  1013  				done:    make(chan bool),
  1014  			}
  1015  			go mlw.flushLoop()
  1016  			defer mlw.stop()
  1017  			dst = mlw
  1018  		}
  1019  	}
  1020  
  1021  	p.copyBuffer(dst, src)
  1022  }
  1023  
  1024  func (p *ReverseProxy) copyBuffer(dst io.Writer, src io.Reader) (int64, error) {
  1025  
  1026  	buf := p.sp.Get().(*[]byte)
  1027  	defer p.sp.Put(buf)
  1028  
  1029  	var written int64
  1030  	for {
  1031  		nr, rerr := src.Read(*buf)
  1032  		if rerr != nil && rerr != io.EOF && rerr != context.Canceled {
  1033  			p.logger.WithFields(logrus.Fields{
  1034  				"prefix": "proxy",
  1035  				"org_id": p.TykAPISpec.OrgID,
  1036  				"api_id": p.TykAPISpec.APIID,
  1037  			}).Error("http: proxy error during body copy: ", rerr)
  1038  		}
  1039  		if nr > 0 {
  1040  			nw, werr := dst.Write((*buf)[:nr])
  1041  			if nw > 0 {
  1042  				written += int64(nw)
  1043  			}
  1044  			if werr != nil {
  1045  				return written, werr
  1046  			}
  1047  			if nr != nw {
  1048  				return written, io.ErrShortWrite
  1049  			}
  1050  		}
  1051  		if rerr != nil {
  1052  			return written, rerr
  1053  		}
  1054  	}
  1055  }
  1056  
  1057  func (p *ReverseProxy) handleUpgradeResponse(rw http.ResponseWriter, req *http.Request, res *http.Response) error {
  1058  	copyHeader(res.Header, rw.Header())
  1059  
  1060  	hj, ok := rw.(http.Hijacker)
  1061  	if !ok {
  1062  		return fmt.Errorf("can't switch protocols using non-Hijacker ResponseWriter type %T", rw)
  1063  	}
  1064  	backConn, ok := res.Body.(io.ReadWriteCloser)
  1065  	if !ok {
  1066  		return fmt.Errorf("internal error: 101 switching protocols response with non-writable body")
  1067  	}
  1068  	defer backConn.Close()
  1069  	conn, brw, err := hj.Hijack()
  1070  	if err != nil {
  1071  		return fmt.Errorf("Hijack failed on protocol switch: %v", err)
  1072  	}
  1073  	defer conn.Close()
  1074  	res.Body = nil // so res.Write only writes the headers; we have res.Body in backConn above
  1075  	if err := res.Write(brw); err != nil {
  1076  		return fmt.Errorf("response write: %v", err)
  1077  	}
  1078  	if err := brw.Flush(); err != nil {
  1079  		return fmt.Errorf("response flush: %v", err)
  1080  	}
  1081  	errc := make(chan error, 1)
  1082  	spc := switchProtocolCopier{user: conn, backend: backConn}
  1083  	go spc.copyToBackend(errc)
  1084  	go spc.copyFromBackend(errc)
  1085  	<-errc
  1086  
  1087  	res.Body = ioutil.NopCloser(strings.NewReader(""))
  1088  
  1089  	return nil
  1090  }
  1091  
  1092  // switchProtocolCopier exists so goroutines proxying data back and
  1093  // forth have nice names in stacks.
  1094  type switchProtocolCopier struct {
  1095  	user, backend io.ReadWriter
  1096  }
  1097  
  1098  func (c switchProtocolCopier) copyFromBackend(errc chan<- error) {
  1099  	_, err := io.Copy(c.user, c.backend)
  1100  	errc <- err
  1101  }
  1102  
  1103  func (c switchProtocolCopier) copyToBackend(errc chan<- error) {
  1104  	_, err := io.Copy(c.backend, c.user)
  1105  	errc <- err
  1106  }
  1107  
  1108  type writeFlusher interface {
  1109  	io.Writer
  1110  	http.Flusher
  1111  }
  1112  
  1113  type maxLatencyWriter struct {
  1114  	dst     writeFlusher
  1115  	latency time.Duration
  1116  
  1117  	mu   sync.Mutex // protects Write + Flush
  1118  	done chan bool
  1119  }
  1120  
  1121  func (m *maxLatencyWriter) Write(p []byte) (int, error) {
  1122  	m.mu.Lock()
  1123  	defer m.mu.Unlock()
  1124  	return m.dst.Write(p)
  1125  }
  1126  
  1127  func (m *maxLatencyWriter) flushLoop() {
  1128  	t := time.NewTicker(m.latency)
  1129  	defer t.Stop()
  1130  	for {
  1131  		select {
  1132  		case <-m.done:
  1133  			return
  1134  		case <-t.C:
  1135  			m.mu.Lock()
  1136  			m.dst.Flush()
  1137  			m.mu.Unlock()
  1138  		}
  1139  	}
  1140  }
  1141  
  1142  func (m *maxLatencyWriter) stop() { m.done <- true }
  1143  
  1144  func requestIPHops(r *http.Request) string {
  1145  	clientIP, _, err := net.SplitHostPort(r.RemoteAddr)
  1146  	if err != nil {
  1147  		return ""
  1148  	}
  1149  	// If we aren't the first proxy retain prior
  1150  	// X-Forwarded-For information as a comma+space
  1151  	// separated list and fold multiple headers into one.
  1152  	if prior, ok := r.Header["X-Forwarded-For"]; ok {
  1153  		clientIP = strings.Join(prior, ", ") + ", " + clientIP
  1154  	}
  1155  	return clientIP
  1156  }
  1157  
  1158  // nopCloser is just like ioutil's, but here to let us re-read the same
  1159  // buffer inside by moving position to the start every time we done with reading
  1160  type nopCloser struct {
  1161  	io.ReadSeeker
  1162  }
  1163  
  1164  // Read just a wrapper around real Read which also moves position to the start if we get EOF
  1165  // to have it ready for next read-cycle
  1166  func (n nopCloser) Read(p []byte) (int, error) {
  1167  	num, err := n.ReadSeeker.Read(p)
  1168  	if err == io.EOF { // move to start to have it ready for next read cycle
  1169  		n.Seek(0, io.SeekStart)
  1170  	}
  1171  	return num, err
  1172  }
  1173  
  1174  // Close is a no-op Close
  1175  func (n nopCloser) Close() error {
  1176  	return nil
  1177  }
  1178  
  1179  func copyBody(body io.ReadCloser) io.ReadCloser {
  1180  	// check if body was already read and converted into our nopCloser
  1181  	if nc, ok := body.(nopCloser); ok {
  1182  		// seek to the beginning to have it ready for next read
  1183  		nc.Seek(0, io.SeekStart)
  1184  		return body
  1185  	}
  1186  
  1187  	// body is http's io.ReadCloser - let's close it after we read data
  1188  	defer body.Close()
  1189  
  1190  	// body is http's io.ReadCloser - read it up until EOF
  1191  	var bodyRead bytes.Buffer
  1192  	io.Copy(&bodyRead, body)
  1193  
  1194  	// use seek-able reader for further body usage
  1195  	reusableBody := bytes.NewReader(bodyRead.Bytes())
  1196  
  1197  	return nopCloser{reusableBody}
  1198  }
  1199  
  1200  func copyRequest(r *http.Request) *http.Request {
  1201  	if r.Body != nil {
  1202  		r.Body = copyBody(r.Body)
  1203  	}
  1204  	return r
  1205  }
  1206  
  1207  func copyResponse(r *http.Response) *http.Response {
  1208  	if r.Body != nil {
  1209  		r.Body = copyBody(r.Body)
  1210  	}
  1211  	return r
  1212  }
  1213  
  1214  func nopCloseRequestBody(r *http.Request) {
  1215  	if r == nil {
  1216  		return
  1217  	}
  1218  
  1219  	copyRequest(r)
  1220  }
  1221  
  1222  func nopCloseResponseBody(r *http.Response) {
  1223  	if r == nil {
  1224  		return
  1225  	}
  1226  
  1227  	copyResponse(r)
  1228  }
  1229  
  1230  func IsUpgrade(req *http.Request) (bool, string) {
  1231  	if !config.Global().HttpServerOptions.EnableWebSockets {
  1232  		return false, ""
  1233  	}
  1234  
  1235  	contentType := strings.ToLower(strings.TrimSpace(req.Header.Get(headers.Accept)))
  1236  	if contentType == "text/event-stream" {
  1237  		return true, ""
  1238  	}
  1239  
  1240  	connection := strings.ToLower(strings.TrimSpace(req.Header.Get(headers.Connection)))
  1241  	if connection != "upgrade" {
  1242  		return false, ""
  1243  	}
  1244  
  1245  	upgrade := strings.ToLower(strings.TrimSpace(req.Header.Get("Upgrade")))
  1246  	if upgrade != "" {
  1247  		return true, upgrade
  1248  	}
  1249  
  1250  	return false, ""
  1251  }