github.com/Ilhicas/nomad@v1.0.4-0.20210304152020-e86851182bc3/helper/pool/pool.go (about)

     1  package pool
     2  
     3  import (
     4  	"container/list"
     5  	"fmt"
     6  	"io"
     7  	"log"
     8  	"net"
     9  	"net/rpc"
    10  	"sync"
    11  	"sync/atomic"
    12  	"time"
    13  
    14  	"github.com/hashicorp/consul/lib"
    15  	hclog "github.com/hashicorp/go-hclog"
    16  	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
    17  	"github.com/hashicorp/nomad/helper/tlsutil"
    18  	"github.com/hashicorp/nomad/nomad/structs"
    19  	"github.com/hashicorp/yamux"
    20  )
    21  
    22  // NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls.
    23  func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
    24  	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
    25  }
    26  
    27  // NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs.
    28  func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
    29  	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
    30  }
    31  
    32  // streamClient is used to wrap a stream with an RPC client
    33  type StreamClient struct {
    34  	stream net.Conn
    35  	codec  rpc.ClientCodec
    36  }
    37  
    38  func (sc *StreamClient) Close() {
    39  	sc.stream.Close()
    40  	sc.codec.Close()
    41  }
    42  
    43  // Conn is a pooled connection to a Nomad server
    44  type Conn struct {
    45  	refCount    int32
    46  	shouldClose int32
    47  
    48  	addr     net.Addr
    49  	session  *yamux.Session
    50  	lastUsed time.Time
    51  	version  int
    52  
    53  	pool *ConnPool
    54  
    55  	clients    *list.List
    56  	clientLock sync.Mutex
    57  }
    58  
    59  // markForUse does all the bookkeeping required to ready a connection for use.
    60  func (c *Conn) markForUse() {
    61  	c.lastUsed = time.Now()
    62  	atomic.AddInt32(&c.refCount, 1)
    63  }
    64  
    65  func (c *Conn) Close() error {
    66  	return c.session.Close()
    67  }
    68  
    69  // getClient is used to get a cached or new client
    70  func (c *Conn) getRPCClient() (*StreamClient, error) {
    71  	// Check for cached client
    72  	c.clientLock.Lock()
    73  	front := c.clients.Front()
    74  	if front != nil {
    75  		c.clients.Remove(front)
    76  	}
    77  	c.clientLock.Unlock()
    78  	if front != nil {
    79  		return front.Value.(*StreamClient), nil
    80  	}
    81  
    82  	// Open a new session
    83  	stream, err := c.session.Open()
    84  	if err != nil {
    85  		return nil, err
    86  	}
    87  
    88  	if _, err := stream.Write([]byte{byte(RpcNomad)}); err != nil {
    89  		stream.Close()
    90  		return nil, err
    91  	}
    92  
    93  	// Create a client codec
    94  	codec := NewClientCodec(stream)
    95  
    96  	// Return a new stream client
    97  	sc := &StreamClient{
    98  		stream: stream,
    99  		codec:  codec,
   100  	}
   101  	return sc, nil
   102  }
   103  
   104  // returnClient is used when done with a stream
   105  // to allow re-use by a future RPC
   106  func (c *Conn) returnClient(client *StreamClient) {
   107  	didSave := false
   108  	c.clientLock.Lock()
   109  	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
   110  		c.clients.PushFront(client)
   111  		didSave = true
   112  
   113  		// If this is a Yamux stream, shrink the internal buffers so that
   114  		// we can GC the idle memory
   115  		if ys, ok := client.stream.(*yamux.Stream); ok {
   116  			ys.Shrink()
   117  		}
   118  	}
   119  	c.clientLock.Unlock()
   120  	if !didSave {
   121  		client.Close()
   122  	}
   123  }
   124  
   125  // ConnPool is used to maintain a connection pool to other
   126  // Nomad servers. This is used to reduce the latency of
   127  // RPC requests between servers. It is only used to pool
   128  // connections in the rpcNomad mode. Raft connections
   129  // are pooled separately.
   130  type ConnPool struct {
   131  	sync.Mutex
   132  
   133  	// logger is the logger to be used
   134  	logger *log.Logger
   135  
   136  	// The maximum time to keep a connection open
   137  	maxTime time.Duration
   138  
   139  	// The maximum number of open streams to keep
   140  	maxStreams int
   141  
   142  	// Pool maps an address to a open connection
   143  	pool map[string]*Conn
   144  
   145  	// limiter is used to throttle the number of connect attempts
   146  	// to a given address. The first thread will attempt a connection
   147  	// and put a channel in here, which all other threads will wait
   148  	// on to close.
   149  	limiter map[string]chan struct{}
   150  
   151  	// TLS wrapper
   152  	tlsWrap tlsutil.RegionWrapper
   153  
   154  	// Used to indicate the pool is shutdown
   155  	shutdown   bool
   156  	shutdownCh chan struct{}
   157  
   158  	// connListener is used to notify a potential listener of a new connection
   159  	// being made.
   160  	connListener chan<- *yamux.Session
   161  }
   162  
   163  // NewPool is used to make a new connection pool
   164  // Maintain at most one connection per host, for up to maxTime.
   165  // Set maxTime to 0 to disable reaping. maxStreams is used to control
   166  // the number of idle streams allowed.
   167  // If TLS settings are provided outgoing connections use TLS.
   168  func NewPool(logger hclog.Logger, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
   169  	pool := &ConnPool{
   170  		logger:     logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}),
   171  		maxTime:    maxTime,
   172  		maxStreams: maxStreams,
   173  		pool:       make(map[string]*Conn),
   174  		limiter:    make(map[string]chan struct{}),
   175  		tlsWrap:    tlsWrap,
   176  		shutdownCh: make(chan struct{}),
   177  	}
   178  	if maxTime > 0 {
   179  		go pool.reap()
   180  	}
   181  	return pool
   182  }
   183  
   184  // Shutdown is used to close the connection pool
   185  func (p *ConnPool) Shutdown() error {
   186  	p.Lock()
   187  	defer p.Unlock()
   188  
   189  	for _, conn := range p.pool {
   190  		conn.Close()
   191  	}
   192  	p.pool = make(map[string]*Conn)
   193  
   194  	if p.shutdown {
   195  		return nil
   196  	}
   197  
   198  	if p.connListener != nil {
   199  		close(p.connListener)
   200  		p.connListener = nil
   201  	}
   202  
   203  	p.shutdown = true
   204  	close(p.shutdownCh)
   205  	return nil
   206  }
   207  
   208  // ReloadTLS reloads TLS configuration on the fly
   209  func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) {
   210  	p.Lock()
   211  	defer p.Unlock()
   212  
   213  	oldPool := p.pool
   214  	for _, conn := range oldPool {
   215  		conn.Close()
   216  	}
   217  	p.pool = make(map[string]*Conn)
   218  	p.tlsWrap = tlsWrap
   219  }
   220  
   221  // SetConnListener is used to listen to new connections being made. The
   222  // channel will be closed when the conn pool is closed or a new listener is set.
   223  func (p *ConnPool) SetConnListener(l chan<- *yamux.Session) {
   224  	p.Lock()
   225  	defer p.Unlock()
   226  
   227  	// Close the old listener
   228  	if p.connListener != nil {
   229  		close(p.connListener)
   230  	}
   231  
   232  	// Store the new listener
   233  	p.connListener = l
   234  }
   235  
   236  // Acquire is used to get a connection that is
   237  // pooled or to return a new connection
   238  func (p *ConnPool) acquire(region string, addr net.Addr, version int) (*Conn, error) {
   239  	// Check to see if there's a pooled connection available. This is up
   240  	// here since it should the vastly more common case than the rest
   241  	// of the code here.
   242  	p.Lock()
   243  	c := p.pool[addr.String()]
   244  	if c != nil {
   245  		c.markForUse()
   246  		p.Unlock()
   247  		return c, nil
   248  	}
   249  
   250  	// If not (while we are still locked), set up the throttling structure
   251  	// for this address, which will make everyone else wait until our
   252  	// attempt is done.
   253  	var wait chan struct{}
   254  	var ok bool
   255  	if wait, ok = p.limiter[addr.String()]; !ok {
   256  		wait = make(chan struct{})
   257  		p.limiter[addr.String()] = wait
   258  	}
   259  	isLeadThread := !ok
   260  	p.Unlock()
   261  
   262  	// If we are the lead thread, make the new connection and then wake
   263  	// everybody else up to see if we got it.
   264  	if isLeadThread {
   265  		c, err := p.getNewConn(region, addr, version)
   266  		p.Lock()
   267  		delete(p.limiter, addr.String())
   268  		close(wait)
   269  		if err != nil {
   270  			p.Unlock()
   271  			return nil, err
   272  		}
   273  
   274  		p.pool[addr.String()] = c
   275  
   276  		// If there is a connection listener, notify them of the new connection.
   277  		if p.connListener != nil {
   278  			select {
   279  			case p.connListener <- c.session:
   280  			default:
   281  			}
   282  		}
   283  
   284  		p.Unlock()
   285  		return c, nil
   286  	}
   287  
   288  	// Otherwise, wait for the lead thread to attempt the connection
   289  	// and use what's in the pool at that point.
   290  	select {
   291  	case <-p.shutdownCh:
   292  		return nil, fmt.Errorf("rpc error: shutdown")
   293  	case <-wait:
   294  	}
   295  
   296  	// See if the lead thread was able to get us a connection.
   297  	p.Lock()
   298  	if c := p.pool[addr.String()]; c != nil {
   299  		c.markForUse()
   300  		p.Unlock()
   301  		return c, nil
   302  	}
   303  
   304  	p.Unlock()
   305  	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
   306  }
   307  
   308  // getNewConn is used to return a new connection
   309  func (p *ConnPool) getNewConn(region string, addr net.Addr, version int) (*Conn, error) {
   310  	// Try to dial the conn
   311  	conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
   312  	if err != nil {
   313  		return nil, err
   314  	}
   315  
   316  	// Cast to TCPConn
   317  	if tcp, ok := conn.(*net.TCPConn); ok {
   318  		tcp.SetKeepAlive(true)
   319  		tcp.SetNoDelay(true)
   320  	}
   321  
   322  	// Check if TLS is enabled
   323  	if p.tlsWrap != nil {
   324  		// Switch the connection into TLS mode
   325  		if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
   326  			conn.Close()
   327  			return nil, err
   328  		}
   329  
   330  		// Wrap the connection in a TLS client
   331  		tlsConn, err := p.tlsWrap(region, conn)
   332  		if err != nil {
   333  			conn.Close()
   334  			return nil, err
   335  		}
   336  		conn = tlsConn
   337  	}
   338  
   339  	// Write the multiplex byte to set the mode
   340  	if _, err := conn.Write([]byte{byte(RpcMultiplexV2)}); err != nil {
   341  		conn.Close()
   342  		return nil, err
   343  	}
   344  
   345  	// Setup the logger
   346  	conf := yamux.DefaultConfig()
   347  	conf.LogOutput = nil
   348  	conf.Logger = p.logger
   349  
   350  	// Create a multiplexed session
   351  	session, err := yamux.Client(conn, conf)
   352  	if err != nil {
   353  		conn.Close()
   354  		return nil, err
   355  	}
   356  
   357  	// Wrap the connection
   358  	c := &Conn{
   359  		refCount: 1,
   360  		addr:     addr,
   361  		session:  session,
   362  		clients:  list.New(),
   363  		lastUsed: time.Now(),
   364  		version:  version,
   365  		pool:     p,
   366  	}
   367  	return c, nil
   368  }
   369  
   370  // clearConn is used to clear any cached connection, potentially in response to
   371  // an error
   372  func (p *ConnPool) clearConn(conn *Conn) {
   373  	// Ensure returned streams are closed
   374  	atomic.StoreInt32(&conn.shouldClose, 1)
   375  
   376  	// Clear from the cache
   377  	p.Lock()
   378  	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
   379  		delete(p.pool, conn.addr.String())
   380  	}
   381  	p.Unlock()
   382  
   383  	// Close down immediately if idle
   384  	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
   385  		conn.Close()
   386  	}
   387  }
   388  
   389  // releaseConn is invoked when we are done with a conn to reduce the ref count
   390  func (p *ConnPool) releaseConn(conn *Conn) {
   391  	refCount := atomic.AddInt32(&conn.refCount, -1)
   392  	if refCount == 0 && atomic.LoadInt32(&conn.shouldClose) == 1 {
   393  		conn.Close()
   394  	}
   395  }
   396  
   397  // getClient is used to get a usable client for an address and protocol version
   398  func (p *ConnPool) getRPCClient(region string, addr net.Addr, version int) (*Conn, *StreamClient, error) {
   399  	retries := 0
   400  START:
   401  	// Try to get a conn first
   402  	conn, err := p.acquire(region, addr, version)
   403  	if err != nil {
   404  		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
   405  	}
   406  
   407  	// Get a client
   408  	client, err := conn.getRPCClient()
   409  	if err != nil {
   410  		p.clearConn(conn)
   411  		p.releaseConn(conn)
   412  
   413  		// Try to redial, possible that the TCP session closed due to timeout
   414  		if retries == 0 {
   415  			retries++
   416  			goto START
   417  		}
   418  		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
   419  	}
   420  	return conn, client, nil
   421  }
   422  
   423  // StreamingRPC is used to make an streaming RPC call.  Callers must
   424  // close the connection when done.
   425  func (p *ConnPool) StreamingRPC(region string, addr net.Addr, version int) (net.Conn, error) {
   426  	conn, err := p.acquire(region, addr, version)
   427  	if err != nil {
   428  		return nil, fmt.Errorf("failed to get conn: %v", err)
   429  	}
   430  
   431  	s, err := conn.session.Open()
   432  	if err != nil {
   433  		return nil, fmt.Errorf("failed to open a streaming connection: %v", err)
   434  	}
   435  
   436  	if _, err := s.Write([]byte{byte(RpcStreaming)}); err != nil {
   437  		conn.Close()
   438  		return nil, err
   439  	}
   440  
   441  	return s, nil
   442  }
   443  
   444  // RPC is used to make an RPC call to a remote host
   445  func (p *ConnPool) RPC(region string, addr net.Addr, version int, method string, args interface{}, reply interface{}) error {
   446  	// Get a usable client
   447  	conn, sc, err := p.getRPCClient(region, addr, version)
   448  	if err != nil {
   449  		return fmt.Errorf("rpc error: %w", err)
   450  	}
   451  
   452  	// Make the RPC call
   453  	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
   454  	if err != nil {
   455  		sc.Close()
   456  
   457  		// If we read EOF, the session is toast. Clear it and open a
   458  		// new session next time
   459  		// See https://github.com/hashicorp/consul/blob/v1.6.3/agent/pool/pool.go#L471-L477
   460  		if lib.IsErrEOF(err) {
   461  			p.clearConn(conn)
   462  		}
   463  
   464  		p.releaseConn(conn)
   465  
   466  		// If the error is an RPC Coded error
   467  		// return the coded error without wrapping
   468  		if structs.IsErrRPCCoded(err) {
   469  			return err
   470  		}
   471  
   472  		// TODO wrap with RPCCoded error instead
   473  		return fmt.Errorf("rpc error: %w", err)
   474  	}
   475  
   476  	// Done with the connection
   477  	conn.returnClient(sc)
   478  	p.releaseConn(conn)
   479  	return nil
   480  }
   481  
   482  // Reap is used to close conns open over maxTime
   483  func (p *ConnPool) reap() {
   484  	for {
   485  		// Sleep for a while
   486  		select {
   487  		case <-p.shutdownCh:
   488  			return
   489  		case <-time.After(time.Second):
   490  		}
   491  
   492  		// Reap all old conns
   493  		p.Lock()
   494  		var removed []string
   495  		now := time.Now()
   496  		for host, conn := range p.pool {
   497  			// Skip recently used connections
   498  			if now.Sub(conn.lastUsed) < p.maxTime {
   499  				continue
   500  			}
   501  
   502  			// Skip connections with active streams
   503  			if atomic.LoadInt32(&conn.refCount) > 0 {
   504  				continue
   505  			}
   506  
   507  			// Close the conn
   508  			conn.Close()
   509  
   510  			// Remove from pool
   511  			removed = append(removed, host)
   512  		}
   513  		for _, host := range removed {
   514  			delete(p.pool, host)
   515  		}
   516  		p.Unlock()
   517  	}
   518  }