github.com/outbrain/consul@v1.4.5/agent/pool/pool.go (about)

     1  package pool
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"net/rpc"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/hashicorp/consul/lib"
    14  	"github.com/hashicorp/consul/tlsutil"
    15  	"github.com/hashicorp/net-rpc-msgpackrpc"
    16  	"github.com/hashicorp/yamux"
    17  )
    18  
    19  const defaultDialTimeout = 10 * time.Second
    20  
// muxSession is used to provide an interface for a stream multiplexer.
// It is the subset of a multiplexed session (satisfied by *yamux.Session)
// that the pool needs: opening new streams and tearing the session down.
type muxSession interface {
	// Open creates a new multiplexed stream over the underlying connection.
	Open() (net.Conn, error)
	// Close shuts down the session (and with it, all of its streams).
	Close() error
}
    26  
// StreamClient is used to wrap a stream with an RPC client codec.
type StreamClient struct {
	stream net.Conn        // the underlying multiplexed stream
	codec  rpc.ClientCodec // msgpack RPC codec layered on the stream
}
    32  
// Close tears down both the stream and its codec. Errors are deliberately
// discarded since the client is being abandoned either way.
func (sc *StreamClient) Close() {
	sc.stream.Close()
	sc.codec.Close()
}
    37  
// Conn is a pooled connection to a Consul server
type Conn struct {
	// refCount counts active users of this conn; updated atomically
	// (see markForUse and ConnPool.releaseConn).
	refCount int32
	// shouldClose is set to 1 by ConnPool.clearConn so that the final
	// releaseConn closes the session instead of caching streams.
	shouldClose int32

	// addr is the remote address; its String() keys the pool map.
	addr net.Addr
	// session multiplexes RPC streams over the single connection.
	session muxSession
	// lastUsed is refreshed by markForUse and consulted by the reaper.
	lastUsed time.Time
	// version is the protocol version this conn was established with.
	version int

	// pool points back at the owning ConnPool (for MaxStreams).
	pool *ConnPool

	// clients caches idle StreamClients for reuse; guarded by clientLock.
	clients    *list.List
	clientLock sync.Mutex
}
    53  
// Close shuts down the multiplexed session, which tears down every
// stream opened from it at once.
func (c *Conn) Close() error {
	return c.session.Close()
}
    57  
    58  // getClient is used to get a cached or new client
    59  func (c *Conn) getClient() (*StreamClient, error) {
    60  	// Check for cached client
    61  	c.clientLock.Lock()
    62  	front := c.clients.Front()
    63  	if front != nil {
    64  		c.clients.Remove(front)
    65  	}
    66  	c.clientLock.Unlock()
    67  	if front != nil {
    68  		return front.Value.(*StreamClient), nil
    69  	}
    70  
    71  	// Open a new session
    72  	stream, err := c.session.Open()
    73  	if err != nil {
    74  		return nil, err
    75  	}
    76  
    77  	// Create the RPC client
    78  	codec := msgpackrpc.NewClientCodec(stream)
    79  
    80  	// Return a new stream client
    81  	sc := &StreamClient{
    82  		stream: stream,
    83  		codec:  codec,
    84  	}
    85  	return sc, nil
    86  }
    87  
    88  // returnStream is used when done with a stream
    89  // to allow re-use by a future RPC
    90  func (c *Conn) returnClient(client *StreamClient) {
    91  	didSave := false
    92  	c.clientLock.Lock()
    93  	if c.clients.Len() < c.pool.MaxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
    94  		c.clients.PushFront(client)
    95  		didSave = true
    96  
    97  		// If this is a Yamux stream, shrink the internal buffers so that
    98  		// we can GC the idle memory
    99  		if ys, ok := client.stream.(*yamux.Stream); ok {
   100  			ys.Shrink()
   101  		}
   102  	}
   103  	c.clientLock.Unlock()
   104  	if !didSave {
   105  		client.Close()
   106  	}
   107  }
   108  
   109  // markForUse does all the bookkeeping required to ready a connection for use.
   110  func (c *Conn) markForUse() {
   111  	c.lastUsed = time.Now()
   112  	atomic.AddInt32(&c.refCount, 1)
   113  }
   114  
// ConnPool is used to maintain a connection pool to other Consul
// servers. This is used to reduce the latency of RPC requests between
// servers. It is only used to pool connections in the rpcConsul mode.
// Raft connections are pooled separately. Maintain at most one
// connection per host, for up to MaxTime. When MaxTime connection
// reaping is disabled. MaxStreams is used to control the number of idle
// streams allowed. If TLS settings are provided outgoing connections
// use TLS.
type ConnPool struct {
	// SrcAddr is the source address for outgoing connections.
	SrcAddr *net.TCPAddr

	// LogOutput is used to control logging
	LogOutput io.Writer

	// MaxTime is the maximum time to keep a connection open; when zero,
	// reaping is disabled (see init, which only starts reap if > 0).
	MaxTime time.Duration

	// MaxStreams is the maximum number of idle streams cached per
	// connection (enforced in Conn.returnClient).
	MaxStreams int

	// TLSWrapper wraps outgoing connections in TLS for a datacenter.
	TLSWrapper tlsutil.DCWrapper

	// ForceTLS is used to enforce outgoing TLS verification
	ForceTLS bool

	// Mutex guards pool, limiter, shutdown, and shutdownCh below.
	sync.Mutex

	// pool maps an address (addr.String()) to an open connection
	pool map[string]*Conn

	// limiter is used to throttle the number of connect attempts
	// to a given address. The first thread will attempt a connection
	// and put a channel in here, which all other threads will wait
	// on to close.
	limiter map[string]chan struct{}

	// shutdown indicates the pool is shut down; shutdownCh is closed at
	// that point to wake any goroutines blocked in acquire or reap.
	shutdown   bool
	shutdownCh chan struct{}

	// once initializes the internal data structures and connection
	// reaping on first use.
	once sync.Once
}
   161  
   162  // init configures the initial data structures. It should be called
   163  // by p.once.Do(p.init) in all public methods.
   164  func (p *ConnPool) init() {
   165  	p.pool = make(map[string]*Conn)
   166  	p.limiter = make(map[string]chan struct{})
   167  	p.shutdownCh = make(chan struct{})
   168  	if p.MaxTime > 0 {
   169  		go p.reap()
   170  	}
   171  }
   172  
   173  // Shutdown is used to close the connection pool
   174  func (p *ConnPool) Shutdown() error {
   175  	p.once.Do(p.init)
   176  
   177  	p.Lock()
   178  	defer p.Unlock()
   179  
   180  	for _, conn := range p.pool {
   181  		conn.Close()
   182  	}
   183  	p.pool = make(map[string]*Conn)
   184  
   185  	if p.shutdown {
   186  		return nil
   187  	}
   188  	p.shutdown = true
   189  	close(p.shutdownCh)
   190  	return nil
   191  }
   192  
// acquire will return a pooled connection, if available. Otherwise it will
// wait for an existing connection attempt to finish, if one is in progress,
// and will return that one if it succeeds. If all else fails, it will return a
// newly-created connection and add it to the pool.
func (p *ConnPool) acquire(dc string, addr net.Addr, version int, useTLS bool) (*Conn, error) {
	addrStr := addr.String()

	// Check to see if there's a pooled connection available. This is up
	// here since it should be the vastly more common case than the rest
	// of the code here.
	p.Lock()
	c := p.pool[addrStr]
	if c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	// If not (while we are still locked), set up the throttling structure
	// for this address, which will make everyone else wait until our
	// attempt is done.
	var wait chan struct{}
	var ok bool
	if wait, ok = p.limiter[addrStr]; !ok {
		wait = make(chan struct{})
		p.limiter[addrStr] = wait
	}
	// The goroutine that created the wait channel is the lead thread; it
	// alone dials, while everyone else blocks on wait being closed.
	isLeadThread := !ok
	p.Unlock()

	// If we are the lead thread, make the new connection and then wake
	// everybody else up to see if we got it.
	if isLeadThread {
		// Dial outside the pool lock, then (re-locked) remove the
		// limiter entry and close wait to release the waiters.
		c, err := p.getNewConn(dc, addr, version, useTLS)
		p.Lock()
		delete(p.limiter, addrStr)
		close(wait)
		if err != nil {
			p.Unlock()
			return nil, err
		}

		// No markForUse here: getNewConn returns the conn with
		// refCount already 1 on behalf of this caller.
		p.pool[addrStr] = c
		p.Unlock()
		return c, nil
	}

	// Otherwise, wait for the lead thread to attempt the connection
	// and use what's in the pool at that point.
	select {
	case <-p.shutdownCh:
		return nil, fmt.Errorf("rpc error: shutdown")
	case <-wait:
	}

	// See if the lead thread was able to get us a connection.
	p.Lock()
	if c := p.pool[addrStr]; c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	p.Unlock()
	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}
   259  
// HalfCloser is an interface that exposes a TCP half-close. We need this
// because we want to expose the raw TCP connection underlying a TLS one in a
// way that's hard to screw up and use for anything else. There's a change
// brewing that will allow us to use the TLS connection for this instead -
// https://go-review.googlesource.com/#/c/25159/.
type HalfCloser interface {
	// CloseWrite half-closes the write side of the connection
	// (satisfied by *net.TCPConn).
	CloseWrite() error
}
   268  
   269  // DialTimeout is used to establish a raw connection to the given server, with a
   270  // given connection timeout.
   271  func (p *ConnPool) DialTimeout(dc string, addr net.Addr, timeout time.Duration, useTLS bool) (net.Conn, HalfCloser, error) {
   272  	p.once.Do(p.init)
   273  
   274  	// Try to dial the conn
   275  	d := &net.Dialer{LocalAddr: p.SrcAddr, Timeout: timeout}
   276  	conn, err := d.Dial("tcp", addr.String())
   277  	if err != nil {
   278  		return nil, nil, err
   279  	}
   280  
   281  	// Cast to TCPConn
   282  	var hc HalfCloser
   283  	if tcp, ok := conn.(*net.TCPConn); ok {
   284  		tcp.SetKeepAlive(true)
   285  		tcp.SetNoDelay(true)
   286  		hc = tcp
   287  	}
   288  
   289  	// Check if TLS is enabled
   290  	if (useTLS || p.ForceTLS) && p.TLSWrapper != nil {
   291  		// Switch the connection into TLS mode
   292  		if _, err := conn.Write([]byte{byte(RPCTLS)}); err != nil {
   293  			conn.Close()
   294  			return nil, nil, err
   295  		}
   296  
   297  		// Wrap the connection in a TLS client
   298  		tlsConn, err := p.TLSWrapper(dc, conn)
   299  		if err != nil {
   300  			conn.Close()
   301  			return nil, nil, err
   302  		}
   303  		conn = tlsConn
   304  	}
   305  
   306  	return conn, hc, nil
   307  }
   308  
   309  // getNewConn is used to return a new connection
   310  func (p *ConnPool) getNewConn(dc string, addr net.Addr, version int, useTLS bool) (*Conn, error) {
   311  	// Get a new, raw connection.
   312  	conn, _, err := p.DialTimeout(dc, addr, defaultDialTimeout, useTLS)
   313  	if err != nil {
   314  		return nil, err
   315  	}
   316  
   317  	// Switch the multiplexing based on version
   318  	var session muxSession
   319  	if version < 2 {
   320  		conn.Close()
   321  		return nil, fmt.Errorf("cannot make client connection, unsupported protocol version %d", version)
   322  	}
   323  
   324  	// Write the Consul multiplex byte to set the mode
   325  	if _, err := conn.Write([]byte{byte(RPCMultiplexV2)}); err != nil {
   326  		conn.Close()
   327  		return nil, err
   328  	}
   329  
   330  	// Setup the logger
   331  	conf := yamux.DefaultConfig()
   332  	conf.LogOutput = p.LogOutput
   333  
   334  	// Create a multiplexed session
   335  	session, _ = yamux.Client(conn, conf)
   336  
   337  	// Wrap the connection
   338  	c := &Conn{
   339  		refCount: 1,
   340  		addr:     addr,
   341  		session:  session,
   342  		clients:  list.New(),
   343  		lastUsed: time.Now(),
   344  		version:  version,
   345  		pool:     p,
   346  	}
   347  	return c, nil
   348  }
   349  
   350  // clearConn is used to clear any cached connection, potentially in response to an error
   351  func (p *ConnPool) clearConn(conn *Conn) {
   352  	// Ensure returned streams are closed
   353  	atomic.StoreInt32(&conn.shouldClose, 1)
   354  
   355  	// Clear from the cache
   356  	addrStr := conn.addr.String()
   357  	p.Lock()
   358  	if c, ok := p.pool[addrStr]; ok && c == conn {
   359  		delete(p.pool, addrStr)
   360  	}
   361  	p.Unlock()
   362  
   363  	// Close down immediately if idle
   364  	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
   365  		conn.Close()
   366  	}
   367  }
   368  
   369  // releaseConn is invoked when we are done with a conn to reduce the ref count
   370  func (p *ConnPool) releaseConn(conn *Conn) {
   371  	refCount := atomic.AddInt32(&conn.refCount, -1)
   372  	if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 {
   373  		conn.Close()
   374  	}
   375  }
   376  
   377  // getClient is used to get a usable client for an address and protocol version
   378  func (p *ConnPool) getClient(dc string, addr net.Addr, version int, useTLS bool) (*Conn, *StreamClient, error) {
   379  	retries := 0
   380  START:
   381  	// Try to get a conn first
   382  	conn, err := p.acquire(dc, addr, version, useTLS)
   383  	if err != nil {
   384  		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
   385  	}
   386  
   387  	// Get a client
   388  	client, err := conn.getClient()
   389  	if err != nil {
   390  		p.clearConn(conn)
   391  		p.releaseConn(conn)
   392  
   393  		// Try to redial, possible that the TCP session closed due to timeout
   394  		if retries == 0 {
   395  			retries++
   396  			goto START
   397  		}
   398  		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
   399  	}
   400  	return conn, client, nil
   401  }
   402  
   403  // RPC is used to make an RPC call to a remote host
   404  func (p *ConnPool) RPC(dc string, addr net.Addr, version int, method string, useTLS bool, args interface{}, reply interface{}) error {
   405  	p.once.Do(p.init)
   406  
   407  	// Get a usable client
   408  	conn, sc, err := p.getClient(dc, addr, version, useTLS)
   409  	if err != nil {
   410  		return fmt.Errorf("rpc error getting client: %v", err)
   411  	}
   412  
   413  	// Make the RPC call
   414  	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
   415  	if err != nil {
   416  		sc.Close()
   417  
   418  		// See the comment in leader_test.go TestLeader_ChangeServerID
   419  		// about how we found this. The tldr is that if we see this
   420  		// error, we know this connection is toast, so we should clear
   421  		// it and make a new one on the next attempt.
   422  		if lib.IsErrEOF(err) {
   423  			p.clearConn(conn)
   424  		}
   425  
   426  		p.releaseConn(conn)
   427  		return fmt.Errorf("rpc error making call: %v", err)
   428  	}
   429  
   430  	// Done with the connection
   431  	conn.returnClient(sc)
   432  	p.releaseConn(conn)
   433  	return nil
   434  }
   435  
   436  // Ping sends a Status.Ping message to the specified server and
   437  // returns true if healthy, false if an error occurred
   438  func (p *ConnPool) Ping(dc string, addr net.Addr, version int, useTLS bool) (bool, error) {
   439  	var out struct{}
   440  	err := p.RPC(dc, addr, version, "Status.Ping", useTLS, struct{}{}, &out)
   441  	return err == nil, err
   442  }
   443  
   444  // Reap is used to close conns open over maxTime
   445  func (p *ConnPool) reap() {
   446  	for {
   447  		// Sleep for a while
   448  		select {
   449  		case <-p.shutdownCh:
   450  			return
   451  		case <-time.After(time.Second):
   452  		}
   453  
   454  		// Reap all old conns
   455  		p.Lock()
   456  		var removed []string
   457  		now := time.Now()
   458  		for host, conn := range p.pool {
   459  			// Skip recently used connections
   460  			if now.Sub(conn.lastUsed) < p.MaxTime {
   461  				continue
   462  			}
   463  
   464  			// Skip connections with active streams
   465  			if atomic.LoadInt32(&conn.refCount) > 0 {
   466  				continue
   467  			}
   468  
   469  			// Close the conn
   470  			conn.Close()
   471  
   472  			// Remove from pool
   473  			removed = append(removed, host)
   474  		}
   475  		for _, host := range removed {
   476  			delete(p.pool, host)
   477  		}
   478  		p.Unlock()
   479  	}
   480  }