github.com/hspak/nomad@v0.7.2-0.20180309000617-bc4ae22a39a5/helper/pool/pool.go (about)

     1  package pool
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"net/rpc"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
    14  	"github.com/hashicorp/nomad/helper/tlsutil"
    15  	"github.com/hashicorp/nomad/nomad/structs"
    16  	"github.com/hashicorp/yamux"
    17  )
    18  
    19  // NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls.
    20  func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
    21  	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle)
    22  }
    23  
    24  // NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs.
    25  func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
    26  	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle)
    27  }
    28  
    29  // streamClient is used to wrap a stream with an RPC client
    30  type StreamClient struct {
    31  	stream net.Conn
    32  	codec  rpc.ClientCodec
    33  }
    34  
    35  func (sc *StreamClient) Close() {
    36  	sc.stream.Close()
    37  }
    38  
// Conn is a pooled connection to a Nomad server. It owns a single yamux client
// session and caches idle StreamClients opened on that session for reuse.
type Conn struct {
	// refCount is the number of callers currently holding this Conn. It is
	// incremented by markForUse, decremented by releaseConn, and accessed
	// with sync/atomic once the Conn is shared.
	refCount    int32
	// shouldClose is set to 1 by clearConn. Once set, returnClient stops
	// caching streams and releaseConn closes the Conn when refCount hits 0.
	shouldClose int32

	// addr is the dialed address; addr.String() is this Conn's key in
	// ConnPool.pool.
	addr     net.Addr
	// session is the multiplexed yamux client session for this connection.
	session  *yamux.Session
	// lastUsed is stamped by markForUse (called with the pool lock held) and
	// compared against the pool's maxTime by reap.
	lastUsed time.Time
	// version is the protocol version supplied when the Conn was dialed.
	version  int

	// pool is the owning pool; returnClient reads its maxStreams limit.
	pool *ConnPool

	// clients holds idle *StreamClient values available for reuse. Streams
	// are pushed to and taken from the front, so the most recently returned
	// stream is reused first. clientLock guards clients.
	clients    *list.List
	clientLock sync.Mutex
}
    54  
    55  // markForUse does all the bookkeeping required to ready a connection for use.
    56  func (c *Conn) markForUse() {
    57  	c.lastUsed = time.Now()
    58  	atomic.AddInt32(&c.refCount, 1)
    59  }
    60  
    61  func (c *Conn) Close() error {
    62  	return c.session.Close()
    63  }
    64  
    65  // getClient is used to get a cached or new client
    66  func (c *Conn) getClient() (*StreamClient, error) {
    67  	// Check for cached client
    68  	c.clientLock.Lock()
    69  	front := c.clients.Front()
    70  	if front != nil {
    71  		c.clients.Remove(front)
    72  	}
    73  	c.clientLock.Unlock()
    74  	if front != nil {
    75  		return front.Value.(*StreamClient), nil
    76  	}
    77  
    78  	// Open a new session
    79  	stream, err := c.session.Open()
    80  	if err != nil {
    81  		return nil, err
    82  	}
    83  
    84  	// Create a client codec
    85  	codec := NewClientCodec(stream)
    86  
    87  	// Return a new stream client
    88  	sc := &StreamClient{
    89  		stream: stream,
    90  		codec:  codec,
    91  	}
    92  	return sc, nil
    93  }
    94  
    95  // returnClient is used when done with a stream
    96  // to allow re-use by a future RPC
    97  func (c *Conn) returnClient(client *StreamClient) {
    98  	didSave := false
    99  	c.clientLock.Lock()
   100  	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
   101  		c.clients.PushFront(client)
   102  		didSave = true
   103  
   104  		// If this is a Yamux stream, shrink the internal buffers so that
   105  		// we can GC the idle memory
   106  		if ys, ok := client.stream.(*yamux.Stream); ok {
   107  			ys.Shrink()
   108  		}
   109  	}
   110  	c.clientLock.Unlock()
   111  	if !didSave {
   112  		client.Close()
   113  	}
   114  }
   115  
// ConnPool is used to maintain a connection pool to other
// Nomad servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
// connections in the rpcNomad mode. Raft connections
// are pooled separately.
//
// At most one Conn is kept per address; each Conn multiplexes many RPC
// streams over a single yamux session.
type ConnPool struct {
	// The embedded mutex guards pool, limiter, shutdown and connListener.
	// ReloadTLS also holds it while replacing tlsWrap.
	sync.Mutex

	// logOutput is the log destination handed to each new yamux session.
	logOutput io.Writer

	// maxTime is how long a connection may go unused before reap closes it
	// (connections with outstanding references are never reaped). NewPool
	// only starts the reaper when maxTime is positive.
	maxTime time.Duration

	// maxStreams is the maximum number of idle streams each Conn keeps cached
	// for reuse; streams returned beyond this limit are closed.
	maxStreams int

	// pool maps an address (addr.String()) to its open connection.
	pool map[string]*Conn

	// limiter is used to throttle the number of connect attempts to a given
	// address. The first caller that misses the pool registers a channel
	// here before dialing; other callers for the same address wait for that
	// channel to be closed and then re-check the pool.
	limiter map[string]chan struct{}

	// tlsWrap, when non-nil, is used to wrap outgoing connections in TLS.
	tlsWrap tlsutil.RegionWrapper

	// shutdown records that Shutdown has run; shutdownCh is closed at the
	// same time, which stops reap and releases callers waiting in acquire.
	shutdown   bool
	shutdownCh chan struct{}

	// connListener, if set, is sent each newly created yamux session. Sends
	// are non-blocking, so a notification is dropped when the listener is not
	// ready to receive it.
	connListener chan<- *yamux.Session
}
   153  
   154  // NewPool is used to make a new connection pool
   155  // Maintain at most one connection per host, for up to maxTime.
   156  // Set maxTime to 0 to disable reaping. maxStreams is used to control
   157  // the number of idle streams allowed.
   158  // If TLS settings are provided outgoing connections use TLS.
   159  func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
   160  	pool := &ConnPool{
   161  		logOutput:  logOutput,
   162  		maxTime:    maxTime,
   163  		maxStreams: maxStreams,
   164  		pool:       make(map[string]*Conn),
   165  		limiter:    make(map[string]chan struct{}),
   166  		tlsWrap:    tlsWrap,
   167  		shutdownCh: make(chan struct{}),
   168  	}
   169  	if maxTime > 0 {
   170  		go pool.reap()
   171  	}
   172  	return pool
   173  }
   174  
   175  // Shutdown is used to close the connection pool
   176  func (p *ConnPool) Shutdown() error {
   177  	p.Lock()
   178  	defer p.Unlock()
   179  
   180  	for _, conn := range p.pool {
   181  		conn.Close()
   182  	}
   183  	p.pool = make(map[string]*Conn)
   184  
   185  	if p.shutdown {
   186  		return nil
   187  	}
   188  
   189  	if p.connListener != nil {
   190  		close(p.connListener)
   191  		p.connListener = nil
   192  	}
   193  
   194  	p.shutdown = true
   195  	close(p.shutdownCh)
   196  	return nil
   197  }
   198  
   199  // ReloadTLS reloads TLS configuration on the fly
   200  func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) {
   201  	p.Lock()
   202  	defer p.Unlock()
   203  
   204  	oldPool := p.pool
   205  	for _, conn := range oldPool {
   206  		conn.Close()
   207  	}
   208  	p.pool = make(map[string]*Conn)
   209  	p.tlsWrap = tlsWrap
   210  }
   211  
   212  // SetConnListener is used to listen to new connections being made. The
   213  // channel will be closed when the conn pool is closed or a new listener is set.
   214  func (p *ConnPool) SetConnListener(l chan<- *yamux.Session) {
   215  	p.Lock()
   216  	defer p.Unlock()
   217  
   218  	// Close the old listener
   219  	if p.connListener != nil {
   220  		close(p.connListener)
   221  	}
   222  
   223  	// Store the new listener
   224  	p.connListener = l
   225  }
   226  
   227  // Acquire is used to get a connection that is
   228  // pooled or to return a new connection
   229  func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) {
   230  	// Check to see if there's a pooled connection available. This is up
   231  	// here since it should the vastly more common case than the rest
   232  	// of the code here.
   233  	p.Lock()
   234  	c := p.pool[addr.String()]
   235  	if c != nil {
   236  		c.markForUse()
   237  		p.Unlock()
   238  		return c, nil
   239  	}
   240  
   241  	// If not (while we are still locked), set up the throttling structure
   242  	// for this address, which will make everyone else wait until our
   243  	// attempt is done.
   244  	var wait chan struct{}
   245  	var ok bool
   246  	if wait, ok = p.limiter[addr.String()]; !ok {
   247  		wait = make(chan struct{})
   248  		p.limiter[addr.String()] = wait
   249  	}
   250  	isLeadThread := !ok
   251  	p.Unlock()
   252  
   253  	// If we are the lead thread, make the new connection and then wake
   254  	// everybody else up to see if we got it.
   255  	if isLeadThread {
   256  		c, err := p.getNewConn(region, addr, version)
   257  		p.Lock()
   258  		delete(p.limiter, addr.String())
   259  		close(wait)
   260  		if err != nil {
   261  			p.Unlock()
   262  			return nil, err
   263  		}
   264  
   265  		p.pool[addr.String()] = c
   266  
   267  		// If there is a connection listener, notify them of the new connection.
   268  		if p.connListener != nil {
   269  			select {
   270  			case p.connListener <- c.session:
   271  			default:
   272  			}
   273  		}
   274  
   275  		p.Unlock()
   276  		return c, nil
   277  	}
   278  
   279  	// Otherwise, wait for the lead thread to attempt the connection
   280  	// and use what's in the pool at that point.
   281  	select {
   282  	case <-p.shutdownCh:
   283  		return nil, fmt.Errorf("rpc error: shutdown")
   284  	case <-wait:
   285  	}
   286  
   287  	// See if the lead thread was able to get us a connection.
   288  	p.Lock()
   289  	if c := p.pool[addr.String()]; c != nil {
   290  		c.markForUse()
   291  		p.Unlock()
   292  		return c, nil
   293  	}
   294  
   295  	p.Unlock()
   296  	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
   297  }
   298  
   299  // getNewConn is used to return a new connection
   300  func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, error) {
   301  	// Try to dial the conn
   302  	conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
   303  	if err != nil {
   304  		return nil, err
   305  	}
   306  
   307  	// Cast to TCPConn
   308  	if tcp, ok := conn.(*net.TCPConn); ok {
   309  		tcp.SetKeepAlive(true)
   310  		tcp.SetNoDelay(true)
   311  	}
   312  
   313  	// Check if TLS is enabled
   314  	if p.tlsWrap != nil {
   315  		// Switch the connection into TLS mode
   316  		if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
   317  			conn.Close()
   318  			return nil, err
   319  		}
   320  
   321  		// Wrap the connection in a TLS client
   322  		tlsConn, err := p.tlsWrap(region, conn)
   323  		if err != nil {
   324  			conn.Close()
   325  			return nil, err
   326  		}
   327  		conn = tlsConn
   328  	}
   329  
   330  	// Write the multiplex byte to set the mode
   331  	if _, err := conn.Write([]byte{byte(RpcMultiplex)}); err != nil {
   332  		conn.Close()
   333  		return nil, err
   334  	}
   335  
   336  	// Setup the logger
   337  	conf := yamux.DefaultConfig()
   338  	conf.LogOutput = p.logOutput
   339  
   340  	// Create a multiplexed session
   341  	session, err := yamux.Client(conn, conf)
   342  	if err != nil {
   343  		conn.Close()
   344  		return nil, err
   345  	}
   346  
   347  	// Wrap the connection
   348  	c := &Conn{
   349  		refCount: 1,
   350  		addr:     addr,
   351  		session:  session,
   352  		clients:  list.New(),
   353  		lastUsed: time.Now(),
   354  		version:  version,
   355  		pool:     p,
   356  	}
   357  	return c, nil
   358  }
   359  
   360  // clearConn is used to clear any cached connection, potentially in response to
   361  // an error
   362  func (p *ConnPool) clearConn(conn *Conn) {
   363  	// Ensure returned streams are closed
   364  	atomic.StoreInt32(&conn.shouldClose, 1)
   365  
   366  	// Clear from the cache
   367  	p.Lock()
   368  	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
   369  		delete(p.pool, conn.addr.String())
   370  	}
   371  	p.Unlock()
   372  
   373  	// Close down immediately if idle
   374  	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
   375  		conn.Close()
   376  	}
   377  }
   378  
   379  // releaseConn is invoked when we are done with a conn to reduce the ref count
   380  func (p *ConnPool) releaseConn(conn *Conn) {
   381  	refCount := atomic.AddInt32(&conn.refCount, -1)
   382  	if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 {
   383  		conn.Close()
   384  	}
   385  }
   386  
   387  // getClient is used to get a usable client for an address and protocol version
   388  func (p *ConnPool) getClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
   389  	retries := 0
   390  START:
   391  	// Try to get a conn first
   392  	conn, err := p.acquire(region, addr, version)
   393  	if err != nil {
   394  		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
   395  	}
   396  
   397  	// Get a client
   398  	client, err := conn.getClient()
   399  	if err != nil {
   400  		p.clearConn(conn)
   401  		p.releaseConn(conn)
   402  
   403  		// Try to redial, possible that the TCP session closed due to timeout
   404  		if retries == 0 {
   405  			retries++
   406  			goto START
   407  		}
   408  		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
   409  	}
   410  	return conn, client, nil
   411  }
   412  
   413  // RPC is used to make an RPC call to a remote host
   414  func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
   415  	// Get a usable client
   416  	conn, sc, err := p.getClient(region, addr, version)
   417  	if err != nil {
   418  		return fmt.Errorf("rpc error: %v", err)
   419  	}
   420  
   421  	// Make the RPC call
   422  	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
   423  	if err != nil {
   424  		sc.Close()
   425  		p.releaseConn(conn)
   426  		return fmt.Errorf("rpc error: %v", err)
   427  	}
   428  
   429  	// Done with the connection
   430  	conn.returnClient(sc)
   431  	p.releaseConn(conn)
   432  	return nil
   433  }
   434  
   435  // Reap is used to close conns open over maxTime
   436  func (p *ConnPool) reap() {
   437  	for {
   438  		// Sleep for a while
   439  		select {
   440  		case <-p.shutdownCh:
   441  			return
   442  		case <-time.After(time.Second):
   443  		}
   444  
   445  		// Reap all old conns
   446  		p.Lock()
   447  		var removed []string
   448  		now := time.Now()
   449  		for host, conn := range p.pool {
   450  			// Skip recently used connections
   451  			if now.Sub(conn.lastUsed) < p.maxTime {
   452  				continue
   453  			}
   454  
   455  			// Skip connections with active streams
   456  			if atomic.LoadInt32(&conn.refCount) > 0 {
   457  				continue
   458  			}
   459  
   460  			// Close the conn
   461  			conn.Close()
   462  
   463  			// Remove from pool
   464  			removed = append(removed, host)
   465  		}
   466  		for _, host := range removed {
   467  			delete(p.pool, host)
   468  		}
   469  		p.Unlock()
   470  	}
   471  }