github.com/smithx10/nomad@v0.9.1-rc1/helper/pool/pool.go (about)

     1  package pool
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  	"io"
     7  	"log"
     8  	"net"
     9  	"net/rpc"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	hclog "github.com/hashicorp/go-hclog"
    15  	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
    16  	"github.com/hashicorp/nomad/helper/tlsutil"
    17  	"github.com/hashicorp/nomad/nomad/structs"
    18  	"github.com/hashicorp/yamux"
    19  )
    20  
    21  // NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls.
    22  func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
    23  	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle)
    24  }
    25  
    26  // NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs.
    27  func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
    28  	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.HashiMsgpackHandle)
    29  }
    30  
    31  // streamClient is used to wrap a stream with an RPC client
    32  type StreamClient struct {
    33  	stream net.Conn
    34  	codec  rpc.ClientCodec
    35  }
    36  
    37  func (sc *StreamClient) Close() {
    38  	sc.stream.Close()
    39  }
    40  
// Conn is a pooled connection to a Nomad server. It wraps a yamux session
// and caches idle StreamClients opened on that session so later RPCs can
// reuse them.
type Conn struct {
	// refCount is the number of outstanding users of the connection. It is
	// updated with sync/atomic: incremented by markForUse and decremented by
	// releaseConn.
	refCount int32
	// shouldClose is set to 1 (atomically) by clearConn. Once set, returned
	// streams are closed instead of cached, and the last releaseConn closes
	// the session.
	shouldClose int32

	// addr is the remote address; its String() form is the pool map key.
	addr net.Addr
	// session is the multiplexed session that streams are opened on.
	session *yamux.Session
	// lastUsed is refreshed by markForUse and compared against maxTime by
	// the reaper.
	lastUsed time.Time
	// version is the protocol version passed in when the connection was
	// created. NOTE(review): not read anywhere in this file.
	version int

	// pool is the owning pool; returnClient reads its maxStreams limit.
	pool *ConnPool

	// clients holds idle StreamClients available for reuse, most recently
	// returned first. It is guarded by clientLock.
	clients    *list.List
	clientLock sync.Mutex
}
    56  
    57  // markForUse does all the bookkeeping required to ready a connection for use.
    58  func (c *Conn) markForUse() {
    59  	c.lastUsed = time.Now()
    60  	atomic.AddInt32(&c.refCount, 1)
    61  }
    62  
    63  func (c *Conn) Close() error {
    64  	return c.session.Close()
    65  }
    66  
    67  // getClient is used to get a cached or new client
    68  func (c *Conn) getClient() (*StreamClient, error) {
    69  	// Check for cached client
    70  	c.clientLock.Lock()
    71  	front := c.clients.Front()
    72  	if front != nil {
    73  		c.clients.Remove(front)
    74  	}
    75  	c.clientLock.Unlock()
    76  	if front != nil {
    77  		return front.Value.(*StreamClient), nil
    78  	}
    79  
    80  	// Open a new session
    81  	stream, err := c.session.Open()
    82  	if err != nil {
    83  		return nil, err
    84  	}
    85  
    86  	// Create a client codec
    87  	codec := NewClientCodec(stream)
    88  
    89  	// Return a new stream client
    90  	sc := &StreamClient{
    91  		stream: stream,
    92  		codec:  codec,
    93  	}
    94  	return sc, nil
    95  }
    96  
    97  // returnClient is used when done with a stream
    98  // to allow re-use by a future RPC
    99  func (c *Conn) returnClient(client *StreamClient) {
   100  	didSave := false
   101  	c.clientLock.Lock()
   102  	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
   103  		c.clients.PushFront(client)
   104  		didSave = true
   105  
   106  		// If this is a Yamux stream, shrink the internal buffers so that
   107  		// we can GC the idle memory
   108  		if ys, ok := client.stream.(*yamux.Stream); ok {
   109  			ys.Shrink()
   110  		}
   111  	}
   112  	c.clientLock.Unlock()
   113  	if !didSave {
   114  		client.Close()
   115  	}
   116  }
   117  
// ConnPool is used to maintain a connection pool to other
// Nomad servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
// connections in the rpcNomad mode. Raft connections
// are pooled separately.
//
// At most one connection is kept per server address (the key of pool).
type ConnPool struct {
	// The embedded mutex guards pool, limiter, shutdown and connListener,
	// and is held by ReloadTLS while it replaces tlsWrap.
	sync.Mutex

	// logger is the logger to be used
	logger *log.Logger

	// maxTime is how long a connection may go unused (measured from its
	// lastUsed time) before the reaper closes it, provided nothing holds a
	// reference to it. The reaper is only started when maxTime > 0.
	maxTime time.Duration

	// maxStreams is the maximum number of idle streams cached per
	// connection for reuse; streams returned beyond that are closed.
	maxStreams int

	// pool maps an address string to an open connection
	pool map[string]*Conn

	// limiter is used to throttle the number of connect attempts
	// to a given address. The first thread will attempt a connection
	// and put a channel in here, which all other threads will wait
	// on to close.
	limiter map[string]chan struct{}

	// tlsWrap, when non-nil, wraps outgoing connections in TLS.
	tlsWrap tlsutil.RegionWrapper

	// shutdown is set once Shutdown has run; shutdownCh is closed at the
	// same time, which stops the reaper and wakes goroutines waiting in
	// acquire.
	shutdown   bool
	shutdownCh chan struct{}

	// connListener is used to notify a potential listener of a new connection
	// being made. Sends are non-blocking, so a notification is dropped when
	// the listener is not ready to receive it.
	connListener chan<- *yamux.Session
}
   155  
   156  // NewPool is used to make a new connection pool
   157  // Maintain at most one connection per host, for up to maxTime.
   158  // Set maxTime to 0 to disable reaping. maxStreams is used to control
   159  // the number of idle streams allowed.
   160  // If TLS settings are provided outgoing connections use TLS.
   161  func NewPool(logger hclog.Logger, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
   162  	pool := &ConnPool{
   163  		logger:     logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}),
   164  		maxTime:    maxTime,
   165  		maxStreams: maxStreams,
   166  		pool:       make(map[string]*Conn),
   167  		limiter:    make(map[string]chan struct{}),
   168  		tlsWrap:    tlsWrap,
   169  		shutdownCh: make(chan struct{}),
   170  	}
   171  	if maxTime > 0 {
   172  		go pool.reap()
   173  	}
   174  	return pool
   175  }
   176  
   177  // Shutdown is used to close the connection pool
   178  func (p *ConnPool) Shutdown() error {
   179  	p.Lock()
   180  	defer p.Unlock()
   181  
   182  	for _, conn := range p.pool {
   183  		conn.Close()
   184  	}
   185  	p.pool = make(map[string]*Conn)
   186  
   187  	if p.shutdown {
   188  		return nil
   189  	}
   190  
   191  	if p.connListener != nil {
   192  		close(p.connListener)
   193  		p.connListener = nil
   194  	}
   195  
   196  	p.shutdown = true
   197  	close(p.shutdownCh)
   198  	return nil
   199  }
   200  
   201  // ReloadTLS reloads TLS configuration on the fly
   202  func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) {
   203  	p.Lock()
   204  	defer p.Unlock()
   205  
   206  	oldPool := p.pool
   207  	for _, conn := range oldPool {
   208  		conn.Close()
   209  	}
   210  	p.pool = make(map[string]*Conn)
   211  	p.tlsWrap = tlsWrap
   212  }
   213  
   214  // SetConnListener is used to listen to new connections being made. The
   215  // channel will be closed when the conn pool is closed or a new listener is set.
   216  func (p *ConnPool) SetConnListener(l chan<- *yamux.Session) {
   217  	p.Lock()
   218  	defer p.Unlock()
   219  
   220  	// Close the old listener
   221  	if p.connListener != nil {
   222  		close(p.connListener)
   223  	}
   224  
   225  	// Store the new listener
   226  	p.connListener = l
   227  }
   228  
   229  // Acquire is used to get a connection that is
   230  // pooled or to return a new connection
   231  func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) {
   232  	// Check to see if there's a pooled connection available. This is up
   233  	// here since it should the vastly more common case than the rest
   234  	// of the code here.
   235  	p.Lock()
   236  	c := p.pool[addr.String()]
   237  	if c != nil {
   238  		c.markForUse()
   239  		p.Unlock()
   240  		return c, nil
   241  	}
   242  
   243  	// If not (while we are still locked), set up the throttling structure
   244  	// for this address, which will make everyone else wait until our
   245  	// attempt is done.
   246  	var wait chan struct{}
   247  	var ok bool
   248  	if wait, ok = p.limiter[addr.String()]; !ok {
   249  		wait = make(chan struct{})
   250  		p.limiter[addr.String()] = wait
   251  	}
   252  	isLeadThread := !ok
   253  	p.Unlock()
   254  
   255  	// If we are the lead thread, make the new connection and then wake
   256  	// everybody else up to see if we got it.
   257  	if isLeadThread {
   258  		c, err := p.getNewConn(region, addr, version)
   259  		p.Lock()
   260  		delete(p.limiter, addr.String())
   261  		close(wait)
   262  		if err != nil {
   263  			p.Unlock()
   264  			return nil, err
   265  		}
   266  
   267  		p.pool[addr.String()] = c
   268  
   269  		// If there is a connection listener, notify them of the new connection.
   270  		if p.connListener != nil {
   271  			select {
   272  			case p.connListener <- c.session:
   273  			default:
   274  			}
   275  		}
   276  
   277  		p.Unlock()
   278  		return c, nil
   279  	}
   280  
   281  	// Otherwise, wait for the lead thread to attempt the connection
   282  	// and use what's in the pool at that point.
   283  	select {
   284  	case <-p.shutdownCh:
   285  		return nil, fmt.Errorf("rpc error: shutdown")
   286  	case <-wait:
   287  	}
   288  
   289  	// See if the lead thread was able to get us a connection.
   290  	p.Lock()
   291  	if c := p.pool[addr.String()]; c != nil {
   292  		c.markForUse()
   293  		p.Unlock()
   294  		return c, nil
   295  	}
   296  
   297  	p.Unlock()
   298  	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
   299  }
   300  
   301  // getNewConn is used to return a new connection
   302  func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, error) {
   303  	// Try to dial the conn
   304  	conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
   305  	if err != nil {
   306  		return nil, err
   307  	}
   308  
   309  	// Cast to TCPConn
   310  	if tcp, ok := conn.(*net.TCPConn); ok {
   311  		tcp.SetKeepAlive(true)
   312  		tcp.SetNoDelay(true)
   313  	}
   314  
   315  	// Check if TLS is enabled
   316  	if p.tlsWrap != nil {
   317  		// Switch the connection into TLS mode
   318  		if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
   319  			conn.Close()
   320  			return nil, err
   321  		}
   322  
   323  		// Wrap the connection in a TLS client
   324  		tlsConn, err := p.tlsWrap(region, conn)
   325  		if err != nil {
   326  			conn.Close()
   327  			return nil, err
   328  		}
   329  		conn = tlsConn
   330  	}
   331  
   332  	// Write the multiplex byte to set the mode
   333  	if _, err := conn.Write([]byte{byte(RpcMultiplex)}); err != nil {
   334  		conn.Close()
   335  		return nil, err
   336  	}
   337  
   338  	// Setup the logger
   339  	conf := yamux.DefaultConfig()
   340  	conf.LogOutput = nil
   341  	conf.Logger = p.logger
   342  
   343  	// Create a multiplexed session
   344  	session, err := yamux.Client(conn, conf)
   345  	if err != nil {
   346  		conn.Close()
   347  		return nil, err
   348  	}
   349  
   350  	// Wrap the connection
   351  	c := &Conn{
   352  		refCount: 1,
   353  		addr:     addr,
   354  		session:  session,
   355  		clients:  list.New(),
   356  		lastUsed: time.Now(),
   357  		version:  version,
   358  		pool:     p,
   359  	}
   360  	return c, nil
   361  }
   362  
   363  // clearConn is used to clear any cached connection, potentially in response to
   364  // an error
   365  func (p *ConnPool) clearConn(conn *Conn) {
   366  	// Ensure returned streams are closed
   367  	atomic.StoreInt32(&conn.shouldClose, 1)
   368  
   369  	// Clear from the cache
   370  	p.Lock()
   371  	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
   372  		delete(p.pool, conn.addr.String())
   373  	}
   374  	p.Unlock()
   375  
   376  	// Close down immediately if idle
   377  	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
   378  		conn.Close()
   379  	}
   380  }
   381  
   382  // releaseConn is invoked when we are done with a conn to reduce the ref count
   383  func (p *ConnPool) releaseConn(conn *Conn) {
   384  	refCount := atomic.AddInt32(&conn.refCount, -1)
   385  	if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 {
   386  		conn.Close()
   387  	}
   388  }
   389  
   390  // getClient is used to get a usable client for an address and protocol version
   391  func (p *ConnPool) getClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
   392  	retries := 0
   393  START:
   394  	// Try to get a conn first
   395  	conn, err := p.acquire(region, addr, version)
   396  	if err != nil {
   397  		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
   398  	}
   399  
   400  	// Get a client
   401  	client, err := conn.getClient()
   402  	if err != nil {
   403  		p.clearConn(conn)
   404  		p.releaseConn(conn)
   405  
   406  		// Try to redial, possible that the TCP session closed due to timeout
   407  		if retries == 0 {
   408  			retries++
   409  			goto START
   410  		}
   411  		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
   412  	}
   413  	return conn, client, nil
   414  }
   415  
   416  // RPC is used to make an RPC call to a remote host
   417  func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
   418  	// Get a usable client
   419  	conn, sc, err := p.getClient(region, addr, version)
   420  	if err != nil {
   421  		return fmt.Errorf("rpc error: %v", err)
   422  	}
   423  
   424  	// Make the RPC call
   425  	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
   426  	if err != nil {
   427  		sc.Close()
   428  		p.releaseConn(conn)
   429  		return fmt.Errorf("rpc error: %v", err)
   430  	}
   431  
   432  	// Done with the connection
   433  	conn.returnClient(sc)
   434  	p.releaseConn(conn)
   435  	return nil
   436  }
   437  
   438  // Reap is used to close conns open over maxTime
   439  func (p *ConnPool) reap() {
   440  	for {
   441  		// Sleep for a while
   442  		select {
   443  		case <-p.shutdownCh:
   444  			return
   445  		case <-time.After(time.Second):
   446  		}
   447  
   448  		// Reap all old conns
   449  		p.Lock()
   450  		var removed []string
   451  		now := time.Now()
   452  		for host, conn := range p.pool {
   453  			// Skip recently used connections
   454  			if now.Sub(conn.lastUsed) < p.maxTime {
   455  				continue
   456  			}
   457  
   458  			// Skip connections with active streams
   459  			if atomic.LoadInt32(&conn.refCount) > 0 {
   460  				continue
   461  			}
   462  
   463  			// Close the conn
   464  			conn.Close()
   465  
   466  			// Remove from pool
   467  			removed = append(removed, host)
   468  		}
   469  		for _, host := range removed {
   470  			delete(p.pool, host)
   471  		}
   472  		p.Unlock()
   473  	}
   474  }