github.com/Tyktechnologies/tyk@v2.9.5+incompatible/gateway/proxy_muxer.go (about)

     1  package gateway
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"net/url"
    10  	"strconv"
    11  	"sync"
    12  	"sync/atomic"
    13  	"time"
    14  
    15  	"github.com/TykTechnologies/again"
    16  	"github.com/TykTechnologies/tyk/config"
    17  	"github.com/TykTechnologies/tyk/tcp"
    18  	proxyproto "github.com/pires/go-proxyproto"
    19  	cache "github.com/pmylund/go-cache"
    20  
    21  	"golang.org/x/net/http2"
    22  
    23  	"github.com/gorilla/mux"
    24  	"github.com/sirupsen/logrus"
    25  )
    26  
// handleWrapper's only purpose is to allow the router to be dynamically
// replaced: an http.Server keeps the same *handleWrapper as its Handler while
// proxyMux.swap assigns a new *mux.Router to the router field.
type handleWrapper struct {
	// router receives every request passed to ServeHTTP.
	router *mux.Router
}
    31  
    32  func (h *handleWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    33  	// make request body to be nopCloser and re-readable before serve it through chain of middlewares
    34  	nopCloseRequestBody(r)
    35  	if NewRelicApplication != nil {
    36  		txn := NewRelicApplication.StartTransaction(r.URL.Path, w, r)
    37  		defer txn.End()
    38  		h.router.ServeHTTP(txn, r)
    39  		return
    40  	}
    41  	h.router.ServeHTTP(w, r)
    42  }
    43  
// proxy is one listening endpoint: a port, its protocol and whichever handler
// serves it (a mux router for http/https, a tcp.Proxy for tcp/tls).
type proxy struct {
	// listener is the bound socket; serve creates it via generateListener
	// while it is nil.
	listener net.Listener
	// port is the listen port; serve overwrites it with the port the new
	// listener actually bound.
	port int
	// protocol is the service protocol; serve knows how to start "tcp",
	// "tls", "http" and "https".
	protocol string
	// useProxyProtocol makes getListener wrap listener with a
	// proxyproto.Listener.
	useProxyProtocol bool
	// router handles requests for http/https proxies (see setRouter).
	router *mux.Router
	// httpServer is created by serve for http/https proxies and shut down by
	// swap when the proxy is removed.
	httpServer *http.Server
	// tcpProxy serves tcp/tls proxies; it is set by addTCPService.
	tcpProxy *tcp.Proxy
	// started is set by serve once the serving goroutine has been launched.
	started bool
}
    54  
    55  func (p proxy) String() string {
    56  	ls := ""
    57  	if p.listener != nil {
    58  		ls = p.listener.Addr().String()
    59  	}
    60  	return fmt.Sprintf("[proxy] :%d %s", p.port, ls)
    61  }
    62  
    63  // getListener returns a net.Listener for this proxy. If useProxyProtocol is
    64  // true it wraps the underlying listener to support proxyprotocol.
    65  func (p proxy) getListener() net.Listener {
    66  	if p.useProxyProtocol {
    67  		return &proxyproto.Listener{Listener: p.listener}
    68  	}
    69  	return p.listener
    70  }
    71  
// proxyMux owns the set of proxies (listeners plus their handlers) the
// gateway serves; getProxy looks them up by port.
type proxyMux struct {
	// Embedded lock; swap holds the write lock while reconciling proxies.
	sync.RWMutex
	// proxies are the known proxies. setRouter, addTCPService and swap only
	// append one when getProxy finds none on its port.
	proxies []*proxy
	// again tracks opened listeners keyed by "address:port" (generateListener
	// registers and reuses them, swap deletes them); presumably this is what
	// lets listeners survive graceful restarts — TODO confirm.
	again again.Again
}
    77  
    78  var defaultProxyMux = &proxyMux{
    79  	again: again.New(),
    80  }
    81  
    82  func (m *proxyMux) getProxy(listenPort int) *proxy {
    83  	if listenPort == 0 {
    84  		listenPort = config.Global().ListenPort
    85  	}
    86  
    87  	for _, p := range m.proxies {
    88  		if p.port == listenPort {
    89  			return p
    90  		}
    91  	}
    92  
    93  	return nil
    94  }
    95  
    96  func (m *proxyMux) router(port int, protocol string) *mux.Router {
    97  	if protocol == "" {
    98  		if config.Global().HttpServerOptions.UseSSL {
    99  			protocol = "https"
   100  		} else {
   101  			protocol = "http"
   102  		}
   103  	}
   104  
   105  	if proxy := m.getProxy(port); proxy != nil {
   106  		if proxy.protocol != protocol {
   107  			mainLog.WithField("port", port).Warningf("Can't get router for protocol %s, router for protocol %s found", protocol, proxy.protocol)
   108  			return nil
   109  		}
   110  
   111  		return proxy.router
   112  	}
   113  
   114  	return nil
   115  }
   116  
   117  func (m *proxyMux) setRouter(port int, protocol string, router *mux.Router) {
   118  	if port == 0 {
   119  		port = config.Global().ListenPort
   120  	}
   121  
   122  	if protocol == "" {
   123  		if config.Global().HttpServerOptions.UseSSL {
   124  			protocol = "https"
   125  		} else {
   126  			protocol = "http"
   127  		}
   128  	}
   129  
   130  	router.SkipClean(config.Global().HttpServerOptions.SkipURLCleaning)
   131  	p := m.getProxy(port)
   132  	if p == nil {
   133  		p = &proxy{
   134  			port:     port,
   135  			protocol: protocol,
   136  			router:   router,
   137  		}
   138  		m.proxies = append(m.proxies, p)
   139  	} else {
   140  		if p.protocol != protocol {
   141  			mainLog.WithFields(logrus.Fields{
   142  				"port":     port,
   143  				"protocol": protocol,
   144  			}).Warningf("Can't update router. Already found service with another protocol %s", p.protocol)
   145  			return
   146  		}
   147  		p.router = router
   148  	}
   149  }
   150  
   151  func (m *proxyMux) addTCPService(spec *APISpec, modifier *tcp.Modifier) {
   152  	hostname := spec.GlobalConfig.HostName
   153  	if spec.GlobalConfig.EnableCustomDomains {
   154  		hostname = spec.Domain
   155  	} else {
   156  		hostname = ""
   157  	}
   158  
   159  	if p := m.getProxy(spec.ListenPort); p != nil {
   160  		p.tcpProxy.AddDomainHandler(hostname, spec.Proxy.TargetURL, modifier)
   161  	} else {
   162  		tlsConfig := tlsClientConfig(spec)
   163  
   164  		p = &proxy{
   165  			port:             spec.ListenPort,
   166  			protocol:         spec.Protocol,
   167  			useProxyProtocol: spec.EnableProxyProtocol,
   168  			tcpProxy: &tcp.Proxy{
   169  				DialTLS:         dialWithServiceDiscovery(spec, customDialTLSCheck(spec, tlsConfig)),
   170  				Dial:            dialWithServiceDiscovery(spec, net.Dial),
   171  				TLSConfigTarget: tlsConfig,
   172  				// SyncStats:       recordTCPHit(spec.APIID, spec.DoNotTrack),
   173  			},
   174  		}
   175  		p.tcpProxy.AddDomainHandler(hostname, spec.Proxy.TargetURL, modifier)
   176  		m.proxies = append(m.proxies, p)
   177  	}
   178  }
   179  
   180  func flushNetworkAnalytics(ctx context.Context) {
   181  	mainLog.Debug("Starting routine for flushing network analytics")
   182  	tick := time.NewTicker(time.Second)
   183  	defer tick.Stop()
   184  	for {
   185  		select {
   186  		case <-ctx.Done():
   187  			return
   188  		case t := <-tick.C:
   189  
   190  			apisMu.RLock()
   191  			for _, spec := range apiSpecs {
   192  				switch spec.Protocol {
   193  				case "tcp", "tls":
   194  					// we only flush network analytics for these services
   195  				default:
   196  					continue
   197  				}
   198  				if spec.DoNotTrack {
   199  					continue
   200  				}
   201  				record := AnalyticsRecord{
   202  					Network:      spec.network.Flush(),
   203  					Day:          t.Day(),
   204  					Month:        t.Month(),
   205  					Year:         t.Year(),
   206  					Hour:         t.Hour(),
   207  					ResponseCode: -1,
   208  					TimeStamp:    t,
   209  					APIName:      spec.Name,
   210  					APIID:        spec.APIID,
   211  					OrgID:        spec.OrgID,
   212  				}
   213  				record.SetExpiry(spec.ExpireAnalyticsAfter)
   214  				analytics.RecordHit(&record)
   215  			}
   216  			apisMu.RUnlock()
   217  		}
   218  	}
   219  }
   220  
   221  func recordTCPHit(specID string, doNotTrack bool) func(tcp.Stat) {
   222  	if doNotTrack {
   223  		return nil
   224  	}
   225  	return func(stat tcp.Stat) {
   226  		// Between reloads, pointers to the actual spec might have changed. The spec
   227  		// id stays the same so we need to pic the latest refence to the spec and
   228  		// update network stats.
   229  		apisMu.RLock()
   230  		spec := apisByID[specID]
   231  		apisMu.RUnlock()
   232  		switch stat.State {
   233  		case tcp.Open:
   234  			atomic.AddInt64(&spec.network.OpenConnections, 1)
   235  		case tcp.Closed:
   236  			atomic.AddInt64(&spec.network.ClosedConnection, 1)
   237  		}
   238  		atomic.AddInt64(&spec.network.BytesIn, stat.BytesIn)
   239  		atomic.AddInt64(&spec.network.BytesOut, stat.BytesOut)
   240  	}
   241  }
   242  
   243  type dialFn func(network string, address string) (net.Conn, error)
   244  
   245  func dialWithServiceDiscovery(spec *APISpec, dial dialFn) dialFn {
   246  	if dial == nil {
   247  		return nil
   248  	}
   249  	if spec.Proxy.ServiceDiscovery.UseDiscoveryService {
   250  		log.Debug("[PROXY] Service discovery enabled")
   251  		if ServiceCache == nil {
   252  			log.Debug("[PROXY] Service cache initialising")
   253  			expiry := 120
   254  			if spec.Proxy.ServiceDiscovery.CacheTimeout > 0 {
   255  				expiry = int(spec.Proxy.ServiceDiscovery.CacheTimeout)
   256  			} else if spec.GlobalConfig.ServiceDiscovery.DefaultCacheTimeout > 0 {
   257  				expiry = spec.GlobalConfig.ServiceDiscovery.DefaultCacheTimeout
   258  			}
   259  			ServiceCache = cache.New(time.Duration(expiry)*time.Second, 15*time.Second)
   260  		}
   261  	}
   262  	return func(network, address string) (net.Conn, error) {
   263  		hostList := spec.Proxy.StructuredTargetList
   264  		target := address
   265  		switch {
   266  		case spec.Proxy.ServiceDiscovery.UseDiscoveryService:
   267  			var err error
   268  			hostList, err = urlFromService(spec)
   269  			if err != nil {
   270  				log.Error("[PROXY] [SERVICE DISCOVERY] Failed target lookup: ", err)
   271  				break
   272  			}
   273  			log.Debug("[PROXY] [SERVICE DISCOVERY] received host list ", hostList.All())
   274  			fallthrough // implies load balancing, with replaced host list
   275  		case spec.Proxy.EnableLoadBalancing:
   276  			host, err := nextTarget(hostList, spec)
   277  			if err != nil {
   278  				log.Error("[PROXY] [LOAD BALANCING] ", err)
   279  				host = allHostsDownURL
   280  			}
   281  			lbRemote, err := url.Parse(host)
   282  			if err != nil {
   283  				log.Error("[PROXY] [LOAD BALANCING] Couldn't parse target URL:", err)
   284  			} else {
   285  				if lbRemote.Scheme == network {
   286  					target = lbRemote.Host
   287  				} else {
   288  					log.Errorf("[PROXY] [LOAD BALANCING] mis match scheme want:%s got: %s", network, lbRemote.Scheme)
   289  				}
   290  			}
   291  		}
   292  		return dial(network, target)
   293  	}
   294  }
   295  
   296  func (m *proxyMux) swap(new *proxyMux) {
   297  	m.Lock()
   298  	defer m.Unlock()
   299  	listenAddress := config.Global().ListenAddress
   300  
   301  	// Shutting down and removing unused listeners/proxies
   302  	i := 0
   303  	for _, curP := range m.proxies {
   304  		match := new.getProxy(curP.port)
   305  		if match == nil || match.protocol != curP.protocol {
   306  			mainLog.Infof("Found unused listener at port %d, shutting down", curP.port)
   307  
   308  			if curP.httpServer != nil {
   309  				ctx, cancel := context.WithTimeout(context.Background(), time.Second*10)
   310  				curP.httpServer.Shutdown(ctx)
   311  				cancel()
   312  			} else if curP.listener != nil {
   313  				curP.listener.Close()
   314  			}
   315  			m.again.Delete(target(listenAddress, curP.port))
   316  		} else {
   317  			m.proxies[i] = curP
   318  			i++
   319  		}
   320  	}
   321  	m.proxies = m.proxies[:i]
   322  
   323  	// Replacing existing routers or starting new listeners
   324  	for _, newP := range new.proxies {
   325  		match := m.getProxy(newP.port)
   326  		if match == nil {
   327  			m.proxies = append(m.proxies, newP)
   328  		} else {
   329  			if match.tcpProxy != nil {
   330  				match.tcpProxy.Swap(newP.tcpProxy)
   331  			}
   332  			match.router = newP.router
   333  			if match.httpServer != nil {
   334  				match.httpServer.Handler.(*handleWrapper).router = newP.router
   335  			}
   336  		}
   337  	}
   338  	p := m.getProxy(config.Global().ListenPort)
   339  	if p != nil && p.router != nil {
   340  		// All APIs processed, now we can healthcheck
   341  		// Add a root message to check all is OK
   342  		p.router.HandleFunc("/"+config.Global().HealthCheckEndpointName, func(w http.ResponseWriter, r *http.Request) {
   343  			fmt.Fprint(w, "Hello Tiki")
   344  		})
   345  	}
   346  	m.serve()
   347  }
   348  
   349  func (m *proxyMux) serve() {
   350  	for _, p := range m.proxies {
   351  		if p.listener == nil {
   352  			listener, err := m.generateListener(p.port, p.protocol)
   353  			if err != nil {
   354  				mainLog.WithError(err).Error("Can't start listener")
   355  				continue
   356  			}
   357  
   358  			_, portS, _ := net.SplitHostPort(listener.Addr().String())
   359  			port, _ := strconv.Atoi(portS)
   360  			p.port = port
   361  			p.listener = listener
   362  		}
   363  		if p.started {
   364  			continue
   365  		}
   366  
   367  		switch p.protocol {
   368  		case "tcp", "tls":
   369  			mainLog.Warning("Starting TCP server on:", p.listener.Addr().String())
   370  			go p.tcpProxy.Serve(p.getListener())
   371  		case "http", "https":
   372  			mainLog.Warning("Starting HTTP server on:", p.listener.Addr().String())
   373  			readTimeout := 120 * time.Second
   374  			writeTimeout := 120 * time.Second
   375  
   376  			if config.Global().HttpServerOptions.ReadTimeout > 0 {
   377  				readTimeout = time.Duration(config.Global().HttpServerOptions.ReadTimeout) * time.Second
   378  			}
   379  
   380  			if config.Global().HttpServerOptions.WriteTimeout > 0 {
   381  				writeTimeout = time.Duration(config.Global().HttpServerOptions.WriteTimeout) * time.Second
   382  			}
   383  
   384  			addr := config.Global().ListenAddress + ":" + strconv.Itoa(p.port)
   385  			p.httpServer = &http.Server{
   386  				Addr:         addr,
   387  				ReadTimeout:  readTimeout,
   388  				WriteTimeout: writeTimeout,
   389  				Handler:      &handleWrapper{p.router},
   390  			}
   391  
   392  			if config.Global().CloseConnections {
   393  				p.httpServer.SetKeepAlivesEnabled(false)
   394  			}
   395  
   396  			go p.httpServer.Serve(p.listener)
   397  		}
   398  
   399  		p.started = true
   400  	}
   401  }
   402  
   403  func target(listenAddress string, listenPort int) string {
   404  	return fmt.Sprintf("%s:%d", listenAddress, listenPort)
   405  }
   406  
   407  func CheckPortWhiteList(w map[string]config.PortWhiteList, listenPort int, protocol string) error {
   408  	if w != nil {
   409  		if ls, ok := w[protocol]; ok {
   410  			if ls.Match(listenPort) {
   411  				return nil
   412  			}
   413  		}
   414  	}
   415  	return fmt.Errorf("%s:%d trying to open disabled port", protocol, listenPort)
   416  }
   417  
   418  func (m *proxyMux) generateListener(listenPort int, protocol string) (l net.Listener, err error) {
   419  	listenAddress := config.Global().ListenAddress
   420  	if !config.Global().DisablePortWhiteList {
   421  		if err := CheckPortWhiteList(config.Global().PortWhiteList, listenPort, protocol); err != nil {
   422  			return nil, err
   423  		}
   424  	}
   425  
   426  	targetPort := listenAddress + ":" + strconv.Itoa(listenPort)
   427  	if ls := m.again.GetListener(targetPort); ls != nil {
   428  		return ls, nil
   429  	}
   430  	switch protocol {
   431  	case "https", "tls":
   432  		mainLog.Infof("--> Using TLS (%s)", protocol)
   433  		httpServerOptions := config.Global().HttpServerOptions
   434  
   435  		tlsConfig := tls.Config{
   436  			GetCertificate:     dummyGetCertificate,
   437  			ServerName:         httpServerOptions.ServerName,
   438  			MinVersion:         httpServerOptions.MinVersion,
   439  			ClientAuth:         tls.NoClientCert,
   440  			InsecureSkipVerify: httpServerOptions.SSLInsecureSkipVerify,
   441  			CipherSuites:       getCipherAliases(httpServerOptions.Ciphers),
   442  		}
   443  
   444  		if httpServerOptions.EnableHttp2 {
   445  			tlsConfig.NextProtos = append(tlsConfig.NextProtos, http2.NextProtoTLS)
   446  		}
   447  
   448  		tlsConfig.GetConfigForClient = getTLSConfigForClient(&tlsConfig, listenPort)
   449  		l, err = tls.Listen("tcp", targetPort, &tlsConfig)
   450  	default:
   451  		mainLog.WithField("port", targetPort).Infof("--> Standard listener (%s)", protocol)
   452  		l, err = net.Listen("tcp", targetPort)
   453  	}
   454  	if err != nil {
   455  		return nil, err
   456  	}
   457  	if err := (&m.again).Listen(targetPort, l); err != nil {
   458  		return nil, err
   459  	}
   460  	return l, nil
   461  }