github.com/lmb/consul@v1.4.1/connect/proxy/listener.go (about) 1 package proxy 2 3 import ( 4 "context" 5 "crypto/tls" 6 "errors" 7 "fmt" 8 "log" 9 "net" 10 "sync" 11 "sync/atomic" 12 "time" 13 14 metrics "github.com/armon/go-metrics" 15 "github.com/hashicorp/consul/api" 16 "github.com/hashicorp/consul/connect" 17 ) 18 19 const ( 20 publicListenerMetricPrefix = "inbound" 21 upstreamMetricPrefix = "upstream" 22 ) 23 24 // Listener is the implementation of a specific proxy listener. It has pluggable 25 // Listen and Dial methods to suit public mTLS vs upstream semantics. It handles 26 // the lifecycle of the listener and all connections opened through it 27 type Listener struct { 28 // Service is the connect service instance to use. 29 Service *connect.Service 30 31 // listenFunc, dialFunc, and bindAddr are set by type-specific constructors. 32 listenFunc func() (net.Listener, error) 33 dialFunc func() (net.Conn, error) 34 bindAddr string 35 36 stopFlag int32 37 stopChan chan struct{} 38 39 // listeningChan is closed when listener is opened successfully. It's really 40 // only for use in tests where we need to coordinate wait for the Serve 41 // goroutine to be running before we proceed trying to connect. On my laptop 42 // this always works out anyway but on constrained VMs and especially docker 43 // containers (e.g. in CI) we often see the Dial routine win the race and get 44 // `connection refused`. Retry loops and sleeps are unpleasant workarounds and 45 // this is cheap and correct. 46 listeningChan chan struct{} 47 48 logger *log.Logger 49 50 // Gauge to track current open connections 51 activeConns int32 52 connWG sync.WaitGroup 53 metricPrefix string 54 metricLabels []metrics.Label 55 } 56 57 // NewPublicListener returns a Listener setup to listen for public mTLS 58 // connections and proxy them to the configured local application over TCP. 59 func NewPublicListener(svc *connect.Service, cfg PublicListenerConfig, 60 logger *log.Logger) *Listener { 61 bindAddr := fmt.Sprintf("%s:%d", cfg.BindAddress, cfg.BindPort) 62 return &Listener{ 63 Service: svc, 64 listenFunc: func() (net.Listener, error) { 65 return tls.Listen("tcp", bindAddr, svc.ServerTLSConfig()) 66 }, 67 dialFunc: func() (net.Conn, error) { 68 return net.DialTimeout("tcp", cfg.LocalServiceAddress, 69 time.Duration(cfg.LocalConnectTimeoutMs)*time.Millisecond) 70 }, 71 bindAddr: bindAddr, 72 stopChan: make(chan struct{}), 73 listeningChan: make(chan struct{}), 74 logger: logger, 75 metricPrefix: publicListenerMetricPrefix, 76 // For now we only label ourselves as source - we could fetch the src 77 // service from cert on each connection and label metrics differently but it 78 // significaly complicates the active connection tracking here and it's not 79 // clear that it's very valuable - on aggregate looking at all _outbound_ 80 // connections across all proxies gets you a full picture of src->dst 81 // traffic. We might expand this later for better debugging of which clients 82 // are abusing a particular service instance but we'll see how valuable that 83 // seems for the extra complication of tracking many gauges here. 84 metricLabels: []metrics.Label{{Name: "dst", Value: svc.Name()}}, 85 } 86 } 87 88 // NewUpstreamListener returns a Listener setup to listen locally for TCP 89 // connections that are proxied to a discovered Connect service instance. 90 func NewUpstreamListener(svc *connect.Service, client *api.Client, 91 cfg UpstreamConfig, logger *log.Logger) *Listener { 92 return newUpstreamListenerWithResolver(svc, cfg, 93 UpstreamResolverFuncFromClient(client), logger) 94 } 95 96 func newUpstreamListenerWithResolver(svc *connect.Service, cfg UpstreamConfig, 97 resolverFunc func(UpstreamConfig) (connect.Resolver, error), 98 logger *log.Logger) *Listener { 99 bindAddr := fmt.Sprintf("%s:%d", cfg.LocalBindAddress, cfg.LocalBindPort) 100 return &Listener{ 101 Service: svc, 102 listenFunc: func() (net.Listener, error) { 103 return net.Listen("tcp", bindAddr) 104 }, 105 dialFunc: func() (net.Conn, error) { 106 rf, err := resolverFunc(cfg) 107 if err != nil { 108 return nil, err 109 } 110 ctx, cancel := context.WithTimeout(context.Background(), 111 cfg.ConnectTimeout()) 112 defer cancel() 113 return svc.Dial(ctx, rf) 114 }, 115 bindAddr: bindAddr, 116 stopChan: make(chan struct{}), 117 listeningChan: make(chan struct{}), 118 logger: logger, 119 metricPrefix: upstreamMetricPrefix, 120 metricLabels: []metrics.Label{ 121 {Name: "src", Value: svc.Name()}, 122 // TODO(banks): namespace support 123 {Name: "dst_type", Value: string(cfg.DestinationType)}, 124 {Name: "dst", Value: cfg.DestinationName}, 125 }, 126 } 127 } 128 129 // Serve runs the listener until it is stopped. It is an error to call Serve 130 // more than once for any given Listener instance. 131 func (l *Listener) Serve() error { 132 // Ensure we mark state closed if we fail before Close is called externally. 133 defer l.Close() 134 135 if atomic.LoadInt32(&l.stopFlag) != 0 { 136 return errors.New("serve called on a closed listener") 137 } 138 139 listen, err := l.listenFunc() 140 if err != nil { 141 return err 142 } 143 close(l.listeningChan) 144 145 for { 146 conn, err := listen.Accept() 147 if err != nil { 148 if atomic.LoadInt32(&l.stopFlag) == 1 { 149 return nil 150 } 151 return err 152 } 153 154 go l.handleConn(conn) 155 } 156 } 157 158 // handleConn is the internal connection handler goroutine. 159 func (l *Listener) handleConn(src net.Conn) { 160 defer src.Close() 161 162 dst, err := l.dialFunc() 163 if err != nil { 164 l.logger.Printf("[ERR] failed to dial: %s", err) 165 return 166 } 167 168 // Track active conn now (first function call) and defer un-counting it when 169 // it closes. 170 defer l.trackConn()() 171 172 // Make sure Close() waits for this conn to be cleaned up. Note defer is 173 // before conn.Close() so runs after defer conn.Close(). 174 l.connWG.Add(1) 175 defer l.connWG.Done() 176 177 // Note no need to defer dst.Close() since conn handles that for us. 178 conn := NewConn(src, dst) 179 defer conn.Close() 180 181 connStop := make(chan struct{}) 182 183 // Run another goroutine to copy the bytes. 184 go func() { 185 err = conn.CopyBytes() 186 if err != nil { 187 l.logger.Printf("[ERR] connection failed: %s", err) 188 } 189 close(connStop) 190 }() 191 192 // Periodically copy stats from conn to metrics (to keep metrics calls out of 193 // the path of every single packet copy). 5 seconds is probably good enough 194 // resolution - statsd and most others tend to summarize with lower resolution 195 // anyway and this amortizes the cost more. 196 var tx, rx uint64 197 statsT := time.NewTicker(5 * time.Second) 198 defer statsT.Stop() 199 200 reportStats := func() { 201 newTx, newRx := conn.Stats() 202 if delta := newTx - tx; delta > 0 { 203 metrics.IncrCounterWithLabels([]string{l.metricPrefix, "tx_bytes"}, 204 float32(newTx-tx), l.metricLabels) 205 } 206 if delta := newRx - rx; delta > 0 { 207 metrics.IncrCounterWithLabels([]string{l.metricPrefix, "rx_bytes"}, 208 float32(newRx-rx), l.metricLabels) 209 } 210 tx, rx = newTx, newRx 211 } 212 // Always report final stats for the conn. 213 defer reportStats() 214 215 // Wait for conn to close 216 for { 217 select { 218 case <-connStop: 219 return 220 case <-l.stopChan: 221 return 222 case <-statsT.C: 223 reportStats() 224 } 225 } 226 } 227 228 // trackConn increments the count of active conns and returns a func() that can 229 // be deferred on to decrement the counter again on connection close. 230 func (l *Listener) trackConn() func() { 231 c := atomic.AddInt32(&l.activeConns, 1) 232 metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c), 233 l.metricLabels) 234 235 return func() { 236 c := atomic.AddInt32(&l.activeConns, -1) 237 metrics.SetGaugeWithLabels([]string{l.metricPrefix, "conns"}, float32(c), 238 l.metricLabels) 239 } 240 } 241 242 // Close terminates the listener and all active connections. 243 func (l *Listener) Close() error { 244 oldFlag := atomic.SwapInt32(&l.stopFlag, 1) 245 if oldFlag == 0 { 246 close(l.stopChan) 247 // Wait for all conns to close 248 l.connWG.Wait() 249 } 250 return nil 251 } 252 253 // Wait for the listener to be ready to accept connections. 254 func (l *Listener) Wait() { 255 <-l.listeningChan 256 } 257 258 // BindAddr returns the address the listen is bound to. 259 func (l *Listener) BindAddr() string { 260 return l.bindAddr 261 }