github.com/diptanu/nomad@v0.5.7-0.20170516172507-d72e86cbe3d9/nomad/pool.go (about)

     1  package nomad
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"net/rpc"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/hashicorp/net-rpc-msgpackrpc"
    14  	"github.com/hashicorp/nomad/helper/tlsutil"
    15  	"github.com/hashicorp/yamux"
    16  )
    17  
    18  // streamClient is used to wrap a stream with an RPC client
    19  type StreamClient struct {
    20  	stream net.Conn
    21  	codec  rpc.ClientCodec
    22  }
    23  
    24  func (sc *StreamClient) Close() {
    25  	sc.stream.Close()
    26  }
    27  
// Conn is a pooled connection to a Nomad server. RPC streams are opened on
// its multiplexed session, and idle stream clients are cached in clients so
// later RPCs can reuse them.
type Conn struct {
	// refCount is the number of callers currently holding this conn. It is
	// accessed atomically: set to 1 when the conn is created, incremented by
	// markForUse and decremented by ConnPool.releaseConn.
	refCount    int32
	// shouldClose is set to 1 (atomically) by ConnPool.clearConn. Once set,
	// returned stream clients are closed instead of cached, and the conn is
	// closed when refCount drops to zero.
	shouldClose int32

	// addr is the remote address; its String() form is the key used for
	// this conn in the pool map.
	addr     net.Addr
	// session is the multiplexed session that streams are opened on.
	session  *yamux.Session
	// lastUsed is set at creation and refreshed by markForUse; the reaper
	// compares it against the pool's maxTime to find idle conns.
	lastUsed time.Time
	// version is the protocol version the conn was created for.
	// NOTE(review): it is stored but not read by any code visible in this
	// file — confirm whether callers elsewhere rely on it.
	version  int

	// pool is the owning pool; its maxStreams bounds the clients cache.
	pool *ConnPool

	// clients caches idle *StreamClient values for reuse; it is guarded by
	// clientLock.
	clients    *list.List
	clientLock sync.Mutex
}
    43  
    44  // markForUse does all the bookkeeping required to ready a connection for use.
    45  func (c *Conn) markForUse() {
    46  	c.lastUsed = time.Now()
    47  	atomic.AddInt32(&c.refCount, 1)
    48  }
    49  
    50  func (c *Conn) Close() error {
    51  	return c.session.Close()
    52  }
    53  
    54  // getClient is used to get a cached or new client
    55  func (c *Conn) getClient() (*StreamClient, error) {
    56  	// Check for cached client
    57  	c.clientLock.Lock()
    58  	front := c.clients.Front()
    59  	if front != nil {
    60  		c.clients.Remove(front)
    61  	}
    62  	c.clientLock.Unlock()
    63  	if front != nil {
    64  		return front.Value.(*StreamClient), nil
    65  	}
    66  
    67  	// Open a new session
    68  	stream, err := c.session.Open()
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	// Create a client codec
    74  	codec := NewClientCodec(stream)
    75  
    76  	// Return a new stream client
    77  	sc := &StreamClient{
    78  		stream: stream,
    79  		codec:  codec,
    80  	}
    81  	return sc, nil
    82  }
    83  
    84  // returnStream is used when done with a stream
    85  // to allow re-use by a future RPC
    86  func (c *Conn) returnClient(client *StreamClient) {
    87  	didSave := false
    88  	c.clientLock.Lock()
    89  	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
    90  		c.clients.PushFront(client)
    91  		didSave = true
    92  
    93  		// If this is a Yamux stream, shrink the internal buffers so that
    94  		// we can GC the idle memory
    95  		if ys, ok := client.stream.(*yamux.Stream); ok {
    96  			ys.Shrink()
    97  		}
    98  	}
    99  	c.clientLock.Unlock()
   100  	if !didSave {
   101  		client.Close()
   102  	}
   103  }
   104  
   105  // ConnPool is used to maintain a connection pool to other
   106  // Nomad servers. This is used to reduce the latency of
   107  // RPC requests between servers. It is only used to pool
   108  // connections in the rpcNomad mode. Raft connections
   109  // are pooled separately.
   110  type ConnPool struct {
   111  	sync.Mutex
   112  
   113  	// LogOutput is used to control logging
   114  	logOutput io.Writer
   115  
   116  	// The maximum time to keep a connection open
   117  	maxTime time.Duration
   118  
   119  	// The maximum number of open streams to keep
   120  	maxStreams int
   121  
   122  	// Pool maps an address to a open connection
   123  	pool map[string]*Conn
   124  
   125  	// limiter is used to throttle the number of connect attempts
   126  	// to a given address. The first thread will attempt a connection
   127  	// and put a channel in here, which all other threads will wait
   128  	// on to close.
   129  	limiter map[string]chan struct{}
   130  
   131  	// TLS wrapper
   132  	tlsWrap tlsutil.RegionWrapper
   133  
   134  	// Used to indicate the pool is shutdown
   135  	shutdown   bool
   136  	shutdownCh chan struct{}
   137  }
   138  
   139  // NewPool is used to make a new connection pool
   140  // Maintain at most one connection per host, for up to maxTime.
   141  // Set maxTime to 0 to disable reaping. maxStreams is used to control
   142  // the number of idle streams allowed.
   143  // If TLS settings are provided outgoing connections use TLS.
   144  func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
   145  	pool := &ConnPool{
   146  		logOutput:  logOutput,
   147  		maxTime:    maxTime,
   148  		maxStreams: maxStreams,
   149  		pool:       make(map[string]*Conn),
   150  		limiter:    make(map[string]chan struct{}),
   151  		tlsWrap:    tlsWrap,
   152  		shutdownCh: make(chan struct{}),
   153  	}
   154  	if maxTime > 0 {
   155  		go pool.reap()
   156  	}
   157  	return pool
   158  }
   159  
   160  // Shutdown is used to close the connection pool
   161  func (p *ConnPool) Shutdown() error {
   162  	p.Lock()
   163  	defer p.Unlock()
   164  
   165  	for _, conn := range p.pool {
   166  		conn.Close()
   167  	}
   168  	p.pool = make(map[string]*Conn)
   169  
   170  	if p.shutdown {
   171  		return nil
   172  	}
   173  	p.shutdown = true
   174  	close(p.shutdownCh)
   175  	return nil
   176  }
   177  
   178  // Acquire is used to get a connection that is
   179  // pooled or to return a new connection
   180  func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) {
   181  	// Check to see if there's a pooled connection available. This is up
   182  	// here since it should the vastly more common case than the rest
   183  	// of the code here.
   184  	p.Lock()
   185  	c := p.pool[addr.String()]
   186  	if c != nil {
   187  		c.markForUse()
   188  		p.Unlock()
   189  		return c, nil
   190  	}
   191  
   192  	// If not (while we are still locked), set up the throttling structure
   193  	// for this address, which will make everyone else wait until our
   194  	// attempt is done.
   195  	var wait chan struct{}
   196  	var ok bool
   197  	if wait, ok = p.limiter[addr.String()]; !ok {
   198  		wait = make(chan struct{})
   199  		p.limiter[addr.String()] = wait
   200  	}
   201  	isLeadThread := !ok
   202  	p.Unlock()
   203  
   204  	// If we are the lead thread, make the new connection and then wake
   205  	// everybody else up to see if we got it.
   206  	if isLeadThread {
   207  		c, err := p.getNewConn(region, addr, version)
   208  		p.Lock()
   209  		delete(p.limiter, addr.String())
   210  		close(wait)
   211  		if err != nil {
   212  			p.Unlock()
   213  			return nil, err
   214  		}
   215  
   216  		p.pool[addr.String()] = c
   217  		p.Unlock()
   218  		return c, nil
   219  	}
   220  
   221  	// Otherwise, wait for the lead thread to attempt the connection
   222  	// and use what's in the pool at that point.
   223  	select {
   224  	case <-p.shutdownCh:
   225  		return nil, fmt.Errorf("rpc error: shutdown")
   226  	case <-wait:
   227  	}
   228  
   229  	// See if the lead thread was able to get us a connection.
   230  	p.Lock()
   231  	if c := p.pool[addr.String()]; c != nil {
   232  		c.markForUse()
   233  		p.Unlock()
   234  		return c, nil
   235  	}
   236  
   237  	p.Unlock()
   238  	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
   239  }
   240  
   241  // getNewConn is used to return a new connection
   242  func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, error) {
   243  	// Try to dial the conn
   244  	conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
   245  	if err != nil {
   246  		return nil, err
   247  	}
   248  
   249  	// Cast to TCPConn
   250  	if tcp, ok := conn.(*net.TCPConn); ok {
   251  		tcp.SetKeepAlive(true)
   252  		tcp.SetNoDelay(true)
   253  	}
   254  
   255  	// Check if TLS is enabled
   256  	if p.tlsWrap != nil {
   257  		// Switch the connection into TLS mode
   258  		if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
   259  			conn.Close()
   260  			return nil, err
   261  		}
   262  
   263  		// Wrap the connection in a TLS client
   264  		tlsConn, err := p.tlsWrap(region, conn)
   265  		if err != nil {
   266  			conn.Close()
   267  			return nil, err
   268  		}
   269  		conn = tlsConn
   270  	}
   271  
   272  	// Write the multiplex byte to set the mode
   273  	if _, err := conn.Write([]byte{byte(rpcMultiplex)}); err != nil {
   274  		conn.Close()
   275  		return nil, err
   276  	}
   277  
   278  	// Setup the logger
   279  	conf := yamux.DefaultConfig()
   280  	conf.LogOutput = p.logOutput
   281  
   282  	// Create a multiplexed session
   283  	session, err := yamux.Client(conn, conf)
   284  	if err != nil {
   285  		conn.Close()
   286  		return nil, err
   287  	}
   288  
   289  	// Wrap the connection
   290  	c := &Conn{
   291  		refCount: 1,
   292  		addr:     addr,
   293  		session:  session,
   294  		clients:  list.New(),
   295  		lastUsed: time.Now(),
   296  		version:  version,
   297  		pool:     p,
   298  	}
   299  	return c, nil
   300  }
   301  
   302  // clearConn is used to clear any cached connection, potentially in response to an erro
   303  func (p *ConnPool) clearConn(conn *Conn) {
   304  	// Ensure returned streams are closed
   305  	atomic.StoreInt32(&conn.shouldClose, 1)
   306  
   307  	// Clear from the cache
   308  	p.Lock()
   309  	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
   310  		delete(p.pool, conn.addr.String())
   311  	}
   312  	p.Unlock()
   313  
   314  	// Close down immediately if idle
   315  	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
   316  		conn.Close()
   317  	}
   318  }
   319  
   320  // releaseConn is invoked when we are done with a conn to reduce the ref count
   321  func (p *ConnPool) releaseConn(conn *Conn) {
   322  	refCount := atomic.AddInt32(&conn.refCount, -1)
   323  	if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 {
   324  		conn.Close()
   325  	}
   326  }
   327  
   328  // getClient is used to get a usable client for an address and protocol version
   329  func (p *ConnPool) getClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
   330  	retries := 0
   331  START:
   332  	// Try to get a conn first
   333  	conn, err := p.acquire(region, addr, version)
   334  	if err != nil {
   335  		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
   336  	}
   337  
   338  	// Get a client
   339  	client, err := conn.getClient()
   340  	if err != nil {
   341  		p.clearConn(conn)
   342  		p.releaseConn(conn)
   343  
   344  		// Try to redial, possible that the TCP session closed due to timeout
   345  		if retries == 0 {
   346  			retries++
   347  			goto START
   348  		}
   349  		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
   350  	}
   351  	return conn, client, nil
   352  }
   353  
   354  // RPC is used to make an RPC call to a remote host
   355  func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
   356  	// Get a usable client
   357  	conn, sc, err := p.getClient(region, addr, version)
   358  	if err != nil {
   359  		return fmt.Errorf("rpc error: %v", err)
   360  	}
   361  
   362  	// Make the RPC call
   363  	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
   364  	if err != nil {
   365  		sc.Close()
   366  		p.releaseConn(conn)
   367  		return fmt.Errorf("rpc error: %v", err)
   368  	}
   369  
   370  	// Done with the connection
   371  	conn.returnClient(sc)
   372  	p.releaseConn(conn)
   373  	return nil
   374  }
   375  
   376  // Reap is used to close conns open over maxTime
   377  func (p *ConnPool) reap() {
   378  	for {
   379  		// Sleep for a while
   380  		select {
   381  		case <-p.shutdownCh:
   382  			return
   383  		case <-time.After(time.Second):
   384  		}
   385  
   386  		// Reap all old conns
   387  		p.Lock()
   388  		var removed []string
   389  		now := time.Now()
   390  		for host, conn := range p.pool {
   391  			// Skip recently used connections
   392  			if now.Sub(conn.lastUsed) < p.maxTime {
   393  				continue
   394  			}
   395  
   396  			// Skip connections with active streams
   397  			if atomic.LoadInt32(&conn.refCount) > 0 {
   398  				continue
   399  			}
   400  
   401  			// Close the conn
   402  			conn.Close()
   403  
   404  			// Remove from pool
   405  			removed = append(removed, host)
   406  		}
   407  		for _, host := range removed {
   408  			delete(p.pool, host)
   409  		}
   410  		p.Unlock()
   411  	}
   412  }