github.com/kjdelisle/consul@v1.4.5/connect/proxy/listener.go (about)

     1  package proxy
     2  
     3  import (
     4  	"context"
     5  	"crypto/tls"
     6  	"errors"
     7  	"fmt"
     8  	"log"
     9  	"net"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	metrics "github.com/armon/go-metrics"
    15  	"github.com/hashicorp/consul/api"
    16  	"github.com/hashicorp/consul/connect"
    17  )
    18  
    19  const (
    20  	publicListenerMetricPrefix = "inbound"
    21  	upstreamMetricPrefix       = "upstream"
    22  )
    23  
    24  // Listener is the implementation of a specific proxy listener. It has pluggable
    25  // Listen and Dial methods to suit public mTLS vs upstream semantics. It handles
    26  // the lifecycle of the listener and all connections opened through it
    27  type Listener struct {
    28  	// Service is the connect service instance to use.
    29  	Service *connect.Service
    30  
    31  	// listenFunc, dialFunc, and bindAddr are set by type-specific constructors.
    32  	listenFunc func() (net.Listener, error)
    33  	dialFunc   func() (net.Conn, error)
    34  	bindAddr   string
    35  
    36  	stopFlag int32
    37  	stopChan chan struct{}
    38  
    39  	// listeningChan is closed when listener is opened successfully. It's really
    40  	// only for use in tests where we need to coordinate wait for the Serve
    41  	// goroutine to be running before we proceed trying to connect. On my laptop
    42  	// this always works out anyway but on constrained VMs and especially docker
    43  	// containers (e.g. in CI) we often see the Dial routine win the race and get
    44  	// `connection refused`. Retry loops and sleeps are unpleasant workarounds and
    45  	// this is cheap and correct.
    46  	listeningChan chan struct{}
    47  
    48  	logger *log.Logger
    49  
    50  	// Gauge to track current open connections
    51  	activeConns  int32
    52  	connWG       sync.WaitGroup
    53  	metricPrefix string
    54  	metricLabels []metrics.Label
    55  }
    56  
    57  // NewPublicListener returns a Listener setup to listen for public mTLS
    58  // connections and proxy them to the configured local application over TCP.
    59  func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig,
    60  	logger *log.Logger) *Listener {
    61  	bindAddr := fmt.Sprintf("%s:%d", cfg.BindAddress, cfg.BindPort)
    62  	return &Listener{
    63  		Service: svc,
    64  		listenFunc: func() (net.Listener, error) {
    65  			return tls.Listen("tcp", bindAddr, svc.ServerTLSConfig())
    66  		},
    67  		dialFunc: func() (net.Conn, error) {
    68  			return net.DialTimeout("tcp", cfg.LocalServiceAddress,
    69  				time.Duration(cfg.LocalConnectTimeoutMs)*time.Millisecond)
    70  		},
    71  		bindAddr:      bindAddr,
    72  		stopChan:      make(chan struct{}),
    73  		listeningChan: make(chan struct{}),
    74  		logger:        logger,
    75  		metricPrefix:  publicListenerMetricPrefix,
    76  		// For now we only label ourselves as source - we could fetch the src
    77  		// service from cert on each connection and label metrics differently but it
    78  		// significaly complicates the active connection tracking here and it's not
    79  		// clear that it's very valuable - on aggregate looking at all _outbound_
    80  		// connections across all proxies gets you a full picture of src->dst
    81  		// traffic. We might expand this later for better debugging of which clients
    82  		// are abusing a particular service instance but we'll see how valuable that
    83  		// seems for the extra complication of tracking many gauges here.
    84  		metricLabels: []metrics.Label{{Name: "dst", Value: svc.Name()}},
    85  	}
    86  }
    87  
    88  // NewUpstreamListener returns a Listener setup to listen locally for TCP
    89  // connections that are proxied to a discovered Connect service instance.
    90  func NewUpstreamListener(svc *connect.Service, client *api.Client,
    91  	cfg UpstreamConfig, logger *log.Logger) *Listener {
    92  	return newUpstreamListenerWithResolver(svc, cfg,
    93  		UpstreamResolverFuncFromClient(client), logger)
    94  }
    95  
    96  func newUpstreamListenerWithResolver(svc *connect.Service, cfg UpstreamConfig,
    97  	resolverFunc func(UpstreamConfig) (connect.Resolver, error),
    98  	logger *log.Logger) *Listener {
    99  	bindAddr := fmt.Sprintf("%s:%d", cfg.LocalBindAddress, cfg.LocalBindPort)
   100  	return &Listener{
   101  		Service: svc,
   102  		listenFunc: func() (net.Listener, error) {
   103  			return net.Listen("tcp", bindAddr)
   104  		},
   105  		dialFunc: func() (net.Conn, error) {
   106  			rf, err := resolverFunc(cfg)
   107  			if err != nil {
   108  				return nil, err
   109  			}
   110  			ctx, cancel := context.WithTimeout(context.Background(),
   111  				cfg.ConnectTimeout())
   112  			defer cancel()
   113  			return svc.Dial(ctx, rf)
   114  		},
   115  		bindAddr:      bindAddr,
   116  		stopChan:      make(chan struct{}),
   117  		listeningChan: make(chan struct{}),
   118  		logger:        logger,
   119  		metricPrefix:  upstreamMetricPrefix,
   120  		metricLabels: []metrics.Label{
   121  			{Name: "src", Value: svc.Name()},
   122  			// TODO(banks): namespace support
   123  			{Name: "dst_type", Value: string(cfg.DestinationType)},
   124  			{Name: "dst", Value: cfg.DestinationName},
   125  		},
   126  	}
   127  }
   128  
   129  // Serve runs the listener until it is stopped. It is an error to call Serve
   130  // more than once for any given Listener instance.
   131  func (l *Listener) Serve() error {
   132  	// Ensure we mark state closed if we fail before Close is called externally.
   133  	defer l.Close()
   134  
   135  	if atomic.LoadInt32(&l.stopFlag) != 0 {
   136  		return errors.New("serve called on a closed listener")
   137  	}
   138  
   139  	listen, err := l.listenFunc()
   140  	if err != nil {
   141  		return err
   142  	}
   143  	close(l.listeningChan)
   144  
   145  	for {
   146  		conn, err := listen.Accept()
   147  		if err != nil {
   148  			if atomic.LoadInt32(&l.stopFlag) == 1 {
   149  				return nil
   150  			}
   151  			return err
   152  		}
   153  
   154  		go l.handleConn(conn)
   155  	}
   156  }
   157  
   158  // handleConn is the internal connection handler goroutine.
   159  func (l *Listener) handleConn(src net.Conn) {
   160  	defer src.Close()
   161  
   162  	dst, err := l.dialFunc()
   163  	if err != nil {
   164  		l.logger.Printf("[ERR] failed to dial: %s", err)
   165  		return
   166  	}
   167  
   168  	// Track active conn now (first function call) and defer un-counting it when
   169  	// it closes.
   170  	defer l.trackConn()()
   171  
   172  	// Make sure Close() waits for this conn to be cleaned up. Note defer is
   173  	// before conn.Close() so runs after defer conn.Close().
   174  	l.connWG.Add(1)
   175  	defer l.connWG.Done()
   176  
   177  	// Note no need to defer dst.Close() since conn handles that for us.
   178  	conn := NewConn(src, dst)
   179  	defer conn.Close()
   180  
   181  	connStop := make(chan struct{})
   182  
   183  	// Run another goroutine to copy the bytes.
   184  	go func() {
   185  		err = conn.CopyBytes()
   186  		if err != nil {
   187  			l.logger.Printf("[ERR] connection failed: %s", err)
   188  		}
   189  		close(connStop)
   190  	}()
   191  
   192  	// Periodically copy stats from conn to metrics (to keep metrics calls out of
   193  	// the path of every single packet copy). 5 seconds is probably good enough
   194  	// resolution - statsd and most others tend to summarize with lower resolution
   195  	// anyway and this amortizes the cost more.
   196  	var tx, rx uint64
   197  	statsT := time.NewTicker(5 * time.Second)
   198  	defer statsT.Stop()
   199  
   200  	reportStats := func() {
   201  		newTx, newRx := conn.Stats()
   202  		if delta := newTx - tx; delta > 0 {
   203  			metrics.IncrCounterWithLabels([]string{l.metricPrefix, "tx_bytes"},
   204  				float32(newTx-tx), l.metricLabels)
   205  		}
   206  		if delta := newRx - rx; delta > 0 {
   207  			metrics.IncrCounterWithLabels([]string{l.metricPrefix, "rx_bytes"},
   208  				float32(newRx-rx), l.metricLabels)
   209  		}
   210  		tx, rx = newTx, newRx
   211  	}
   212  	// Always report final stats for the conn.
   213  	defer reportStats()
   214  
   215  	// Wait for conn to close
   216  	for {
   217  		select {
   218  		case <-connStop:
   219  			return
   220  		case <-l.stopChan:
   221  			return
   222  		case <-statsT.C:
   223  			reportStats()
   224  		}
   225  	}
   226  }
   227  
   228  // trackConn increments the count of active conns and returns a func() that can
   229  // be deferred on to decrement the counter again on connection close.
   230  func (l *Listener) trackConn() func() {
   231  	c := atomic.AddInt32(&l.activeConns, 1)
   232  	metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c),
   233  		l.metricLabels)
   234  
   235  	return func() {
   236  		c := atomic.AddInt32(&l.activeConns, -1)
   237  		metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c),
   238  			l.metricLabels)
   239  	}
   240  }
   241  
   242  // Close terminates the listener and all active connections.
   243  func (l *Listener) Close() error {
   244  	oldFlag := atomic.SwapInt32(&l.stopFlag, 1)
   245  	if oldFlag == 0 {
   246  		close(l.stopChan)
   247  		// Wait for all conns to close
   248  		l.connWG.Wait()
   249  	}
   250  	return nil
   251  }
   252  
   253  // Wait for the listener to be ready to accept connections.
   254  func (l *Listener) Wait() {
   255  	<-l.listeningChan
   256  }
   257  
   258  // BindAddr returns the address the listen is bound to.
   259  func (l *Listener) BindAddr() string {
   260  	return l.bindAddr
   261  }