github.com/kardianos/nomad@v0.1.3-0.20151022182107-b13df73ee850/nomad/pool.go (about)

     1  package nomad
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"net/rpc"
     9  	"sync"
    10  	"sync/atomic"
    11  	"time"
    12  
    13  	"github.com/hashicorp/consul/tlsutil"
    14  	"github.com/hashicorp/net-rpc-msgpackrpc"
    15  	"github.com/hashicorp/yamux"
    16  )
    17  
    18  // streamClient is used to wrap a stream with an RPC client
    19  type StreamClient struct {
    20  	stream net.Conn
    21  	codec  rpc.ClientCodec
    22  }
    23  
    24  func (sc *StreamClient) Close() {
    25  	sc.stream.Close()
    26  }
    27  
// Conn is a pooled connection to a Nomad server. All RPC streams to that
// server are multiplexed over its single yamux session.
type Conn struct {
	// refCount is the number of callers currently using this connection.
	// It is only accessed with sync/atomic: markForUse increments it and
	// releaseConn decrements it. reap skips connections where it is > 0.
	refCount    int32
	// shouldClose is atomically set to 1 by clearConn once the connection
	// has been evicted from the pool. While set, stream clients handed back
	// via returnClient are closed instead of cached, and the connection is
	// closed once refCount reaches 0.
	shouldClose int32

	// addr is the server address; addr.String() is the key in ConnPool.pool.
	addr     net.Addr
	// session is the yamux session on which getClient opens new streams.
	session  *yamux.Session
	// lastUsed is refreshed by markForUse and compared against
	// ConnPool.maxTime by reap. Code in this file reads and writes it while
	// holding the pool lock (or before the Conn is shared).
	lastUsed time.Time
	// version is recorded when the connection is dialed.
	// NOTE(review): acquire does not compare it when reusing a pooled
	// connection — confirm that is intended.
	version  int

	// pool is the owning pool; returnClient reads its maxStreams limit.
	pool *ConnPool

	// clients caches idle StreamClients for reuse, most recently returned
	// at the front. It is guarded by clientLock.
	clients    *list.List
	clientLock sync.Mutex
}
    43  
    44  // markForUse does all the bookkeeping required to ready a connection for use.
    45  func (c *Conn) markForUse() {
    46  	c.lastUsed = time.Now()
    47  	atomic.AddInt32(&c.refCount, 1)
    48  }
    49  
    50  func (c *Conn) Close() error {
    51  	return c.session.Close()
    52  }
    53  
    54  // getClient is used to get a cached or new client
    55  func (c *Conn) getClient() (*StreamClient, error) {
    56  	// Check for cached client
    57  	c.clientLock.Lock()
    58  	front := c.clients.Front()
    59  	if front != nil {
    60  		c.clients.Remove(front)
    61  	}
    62  	c.clientLock.Unlock()
    63  	if front != nil {
    64  		return front.Value.(*StreamClient), nil
    65  	}
    66  
    67  	// Open a new session
    68  	stream, err := c.session.Open()
    69  	if err != nil {
    70  		return nil, err
    71  	}
    72  
    73  	// Create a client codec
    74  	codec := msgpackrpc.NewClientCodec(stream)
    75  
    76  	// Return a new stream client
    77  	sc := &StreamClient{
    78  		stream: stream,
    79  		codec:  codec,
    80  	}
    81  	return sc, nil
    82  }
    83  
    84  // returnStream is used when done with a stream
    85  // to allow re-use by a future RPC
    86  func (c *Conn) returnClient(client *StreamClient) {
    87  	didSave := false
    88  	c.clientLock.Lock()
    89  	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
    90  		c.clients.PushFront(client)
    91  		didSave = true
    92  	}
    93  	c.clientLock.Unlock()
    94  	if !didSave {
    95  		client.Close()
    96  	}
    97  }
    98  
// ConnPool is used to maintain a connection pool to other
// Nomad servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
// connections in the rpcNomad mode. Raft connections
// are pooled separately.
//
// At most one connection per server address is kept, and RPC streams are
// multiplexed over it with yamux.
//
// NOTE(review): getNewConn switches new connections into rpcMultiplex mode
// (after an optional rpcTLS upgrade); confirm this is what "rpcNomad mode"
// above refers to.
type ConnPool struct {
	// The embedded mutex guards pool, limiter and shutdown.
	sync.Mutex

	// logOutput is passed to yamux as the session's LogOutput.
	logOutput io.Writer

	// The maximum time to keep an idle connection open. reap closes
	// connections unused for at least this long that have no active
	// references. A non-positive value disables reaping (see NewPool).
	maxTime time.Duration

	// The maximum number of idle streams each Conn keeps cached for reuse.
	maxStreams int

	// pool maps an address (addr.String()) to an open connection.
	pool map[string]*Conn

	// limiter is used to throttle the number of connect attempts
	// to a given address. The first thread will attempt a connection
	// and put a channel in here, which all other threads will wait
	// on to close.
	limiter map[string]chan struct{}

	// TLS wrapper. When non-nil, getNewConn calls it with the region and
	// the raw connection to upgrade outgoing connections to TLS.
	tlsWrap tlsutil.DCWrapper

	// Used to indicate the pool is shut down. shutdownCh is closed once, by
	// the first Shutdown call; reap and waiting acquire calls select on it.
	shutdown   bool
	shutdownCh chan struct{}
}
   132  
   133  // NewPool is used to make a new connection pool
   134  // Maintain at most one connection per host, for up to maxTime.
   135  // Set maxTime to 0 to disable reaping. maxStreams is used to control
   136  // the number of idle streams allowed.
   137  // If TLS settings are provided outgoing connections use TLS.
   138  func NewPool(logOutput io.Writer, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.DCWrapper) *ConnPool {
   139  	pool := &ConnPool{
   140  		logOutput:  logOutput,
   141  		maxTime:    maxTime,
   142  		maxStreams: maxStreams,
   143  		pool:       make(map[string]*Conn),
   144  		limiter:    make(map[string]chan struct{}),
   145  		tlsWrap:    tlsWrap,
   146  		shutdownCh: make(chan struct{}),
   147  	}
   148  	if maxTime > 0 {
   149  		go pool.reap()
   150  	}
   151  	return pool
   152  }
   153  
   154  // Shutdown is used to close the connection pool
   155  func (p *ConnPool) Shutdown() error {
   156  	p.Lock()
   157  	defer p.Unlock()
   158  
   159  	for _, conn := range p.pool {
   160  		conn.Close()
   161  	}
   162  	p.pool = make(map[string]*Conn)
   163  
   164  	if p.shutdown {
   165  		return nil
   166  	}
   167  	p.shutdown = true
   168  	close(p.shutdownCh)
   169  	return nil
   170  }
   171  
   172  // Acquire is used to get a connection that is
   173  // pooled or to return a new connection
   174  func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) {
   175  	// Check to see if there's a pooled connection available. This is up
   176  	// here since it should the the vastly more common case than the rest
   177  	// of the code here.
   178  	p.Lock()
   179  	c := p.pool[addr.String()]
   180  	if c != nil {
   181  		c.markForUse()
   182  		p.Unlock()
   183  		return c, nil
   184  	}
   185  
   186  	// If not (while we are still locked), set up the throttling structure
   187  	// for this address, which will make everyone else wait until our
   188  	// attempt is done.
   189  	var wait chan struct{}
   190  	var ok bool
   191  	if wait, ok = p.limiter[addr.String()]; !ok {
   192  		wait = make(chan struct{})
   193  		p.limiter[addr.String()] = wait
   194  	}
   195  	isLeadThread := !ok
   196  	p.Unlock()
   197  
   198  	// If we are the lead thread, make the new connection and then wake
   199  	// everybody else up to see if we got it.
   200  	if isLeadThread {
   201  		c, err := p.getNewConn(region, addr, version)
   202  		p.Lock()
   203  		delete(p.limiter, addr.String())
   204  		close(wait)
   205  		if err != nil {
   206  			p.Unlock()
   207  			return nil, err
   208  		}
   209  
   210  		p.pool[addr.String()] = c
   211  		p.Unlock()
   212  		return c, nil
   213  	}
   214  
   215  	// Otherwise, wait for the lead thread to attempt the connection
   216  	// and use what's in the pool at that point.
   217  	select {
   218  	case <-p.shutdownCh:
   219  		return nil, fmt.Errorf("rpc error: shutdown")
   220  	case <-wait:
   221  	}
   222  
   223  	// See if the lead thread was able to get us a connection.
   224  	p.Lock()
   225  	if c := p.pool[addr.String()]; c != nil {
   226  		c.markForUse()
   227  		p.Unlock()
   228  		return c, nil
   229  	}
   230  
   231  	p.Unlock()
   232  	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
   233  }
   234  
   235  // getNewConn is used to return a new connection
   236  func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, error) {
   237  	// Try to dial the conn
   238  	conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
   239  	if err != nil {
   240  		return nil, err
   241  	}
   242  
   243  	// Cast to TCPConn
   244  	if tcp, ok := conn.(*net.TCPConn); ok {
   245  		tcp.SetKeepAlive(true)
   246  		tcp.SetNoDelay(true)
   247  	}
   248  
   249  	// Check if TLS is enabled
   250  	if p.tlsWrap != nil {
   251  		// Switch the connection into TLS mode
   252  		if _, err := conn.Write([]byte{byte(rpcTLS)}); err != nil {
   253  			conn.Close()
   254  			return nil, err
   255  		}
   256  
   257  		// Wrap the connection in a TLS client
   258  		tlsConn, err := p.tlsWrap(region, conn)
   259  		if err != nil {
   260  			conn.Close()
   261  			return nil, err
   262  		}
   263  		conn = tlsConn
   264  	}
   265  
   266  	// Write the multiplex byte to set the mode
   267  	if _, err := conn.Write([]byte{byte(rpcMultiplex)}); err != nil {
   268  		conn.Close()
   269  		return nil, err
   270  	}
   271  
   272  	// Setup the logger
   273  	conf := yamux.DefaultConfig()
   274  	conf.LogOutput = p.logOutput
   275  
   276  	// Create a multiplexed session
   277  	session, err := yamux.Client(conn, conf)
   278  	if err != nil {
   279  		conn.Close()
   280  		return nil, err
   281  	}
   282  
   283  	// Wrap the connection
   284  	c := &Conn{
   285  		refCount: 1,
   286  		addr:     addr,
   287  		session:  session,
   288  		clients:  list.New(),
   289  		lastUsed: time.Now(),
   290  		version:  version,
   291  		pool:     p,
   292  	}
   293  	return c, nil
   294  }
   295  
   296  // clearConn is used to clear any cached connection, potentially in response to an erro
   297  func (p *ConnPool) clearConn(conn *Conn) {
   298  	// Ensure returned streams are closed
   299  	atomic.StoreInt32(&conn.shouldClose, 1)
   300  
   301  	// Clear from the cache
   302  	p.Lock()
   303  	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
   304  		delete(p.pool, conn.addr.String())
   305  	}
   306  	p.Unlock()
   307  
   308  	// Close down immediately if idle
   309  	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
   310  		conn.Close()
   311  	}
   312  }
   313  
   314  // releaseConn is invoked when we are done with a conn to reduce the ref count
   315  func (p *ConnPool) releaseConn(conn *Conn) {
   316  	refCount := atomic.AddInt32(&conn.refCount, -1)
   317  	if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 {
   318  		conn.Close()
   319  	}
   320  }
   321  
   322  // getClient is used to get a usable client for an address and protocol version
   323  func (p *ConnPool) getClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
   324  	retries := 0
   325  START:
   326  	// Try to get a conn first
   327  	conn, err := p.acquire(region, addr, version)
   328  	if err != nil {
   329  		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
   330  	}
   331  
   332  	// Get a client
   333  	client, err := conn.getClient()
   334  	if err != nil {
   335  		p.clearConn(conn)
   336  		p.releaseConn(conn)
   337  
   338  		// Try to redial, possible that the TCP session closed due to timeout
   339  		if retries == 0 {
   340  			retries++
   341  			goto START
   342  		}
   343  		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
   344  	}
   345  	return conn, client, nil
   346  }
   347  
   348  // RPC is used to make an RPC call to a remote host
   349  func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
   350  	// Get a usable client
   351  	conn, sc, err := p.getClient(region, addr, version)
   352  	if err != nil {
   353  		return fmt.Errorf("rpc error: %v", err)
   354  	}
   355  
   356  	// Make the RPC call
   357  	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
   358  	if err != nil {
   359  		sc.Close()
   360  		p.releaseConn(conn)
   361  		return fmt.Errorf("rpc error: %v", err)
   362  	}
   363  
   364  	// Done with the connection
   365  	conn.returnClient(sc)
   366  	p.releaseConn(conn)
   367  	return nil
   368  }
   369  
   370  // Reap is used to close conns open over maxTime
   371  func (p *ConnPool) reap() {
   372  	for {
   373  		// Sleep for a while
   374  		select {
   375  		case <-p.shutdownCh:
   376  			return
   377  		case <-time.After(time.Second):
   378  		}
   379  
   380  		// Reap all old conns
   381  		p.Lock()
   382  		var removed []string
   383  		now := time.Now()
   384  		for host, conn := range p.pool {
   385  			// Skip recently used connections
   386  			if now.Sub(conn.lastUsed) < p.maxTime {
   387  				continue
   388  			}
   389  
   390  			// Skip connections with active streams
   391  			if atomic.LoadInt32(&conn.refCount) > 0 {
   392  				continue
   393  			}
   394  
   395  			// Close the conn
   396  			conn.Close()
   397  
   398  			// Remove from pool
   399  			removed = append(removed, host)
   400  		}
   401  		for _, host := range removed {
   402  			delete(p.pool, host)
   403  		}
   404  		p.Unlock()
   405  	}
   406  }