// github.com/anth0d/nomad@v0.0.0-20221214183521-ae3a0a2cad06/helper/pool/pool.go

package pool

import (
	"container/list"
	"fmt"
	"io"
	"log"
	"net"
	"net/rpc"
	"sync"
	"sync/atomic"
	"time"

	hclog "github.com/hashicorp/go-hclog"
	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
	"github.com/hashicorp/nomad/helper"
	"github.com/hashicorp/nomad/helper/tlsutil"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/yamux"
)

// NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls.
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
}

// NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs.
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
}

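// A minimal sketch of how the codec pair fits together over any
// io.ReadWriteCloser; "Svc.Method" is a placeholder for a registered net/rpc
// endpoint, and net.Pipe is illustrative only (not how the pool wires things):
//
//	clientConn, serverConn := net.Pipe()
//	go rpc.ServeCodec(NewServerCodec(serverConn))
//	err := msgpackrpc.CallWithCodec(NewClientCodec(clientConn), "Svc.Method", args, &reply)
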
    32  // streamClient is used to wrap a stream with an RPC client
type StreamClient struct {
	stream net.Conn
	codec  rpc.ClientCodec
}

func (sc *StreamClient) Close() {
	sc.stream.Close()
	sc.codec.Close()
}

// Conn is a pooled connection to a Nomad server
type Conn struct {
	refCount    int32
	shouldClose int32

	addr     net.Addr
	session  *yamux.Session
	lastUsed atomic.Pointer[time.Time]

	pool *ConnPool

	clients    *list.List
	clientLock sync.Mutex
}

// markForUse does all the bookkeeping required to ready a connection for use,
// and ensures that active connections don't get reaped.
func (c *Conn) markForUse() {
	now := time.Now()
	c.lastUsed.Store(&now)
	atomic.AddInt32(&c.refCount, 1)
}

// releaseUse is the complement of markForUse: it drops the reference count,
// closing the connection once it is idle and marked for closing.
func (c *Conn) releaseUse() {
	refCount := atomic.AddInt32(&c.refCount, -1)
	if refCount == 0 && atomic.LoadInt32(&c.shouldClose) == 1 {
		c.Close()
	}
}

func (c *Conn) Close() error {
	return c.session.Close()
}

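// A sketch of the reference-counting protocol (not real call sites): every
// markForUse must eventually be balanced by a releaseUse. clearConn only
// flags shouldClose, so the final releaseUse performs the actual Close:
//
//	c.markForUse()        // refCount 0 -> 1; the reaper will skip c
//	defer c.releaseUse()  // refCount 1 -> 0; closes c if shouldClose is set
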
// getRPCClient is used to get a cached or new client
func (c *Conn) getRPCClient() (*StreamClient, error) {
	// Check for cached client
	c.clientLock.Lock()
	front := c.clients.Front()
	if front != nil {
		c.clients.Remove(front)
	}
	c.clientLock.Unlock()
	if front != nil {
		return front.Value.(*StreamClient), nil
	}

	// Open a new stream on the multiplexed session
	stream, err := c.session.Open()
	if err != nil {
		return nil, err
	}

	// Send the mode byte so the server routes this stream to the RPC handler
	if _, err := stream.Write([]byte{byte(RpcNomad)}); err != nil {
		stream.Close()
		return nil, err
	}

	// Create a client codec
	codec := NewClientCodec(stream)

	// Return a new stream client
	sc := &StreamClient{
		stream: stream,
		codec:  codec,
	}
	return sc, nil
}

// returnClient is used when done with a stream, so that it can
// be reused by a future RPC
func (c *Conn) returnClient(client *StreamClient) {
	didSave := false
	c.clientLock.Lock()
	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
		c.clients.PushFront(client)
		didSave = true

		// If this is a Yamux stream, shrink the internal buffers so that
		// we can GC the idle memory
		if ys, ok := client.stream.(*yamux.Stream); ok {
			ys.Shrink()
		}
	}
	c.clientLock.Unlock()
	if !didSave {
		client.Close()
	}
}

func (c *Conn) IsClosed() bool {
	return c.session.IsClosed()
}

func (c *Conn) AcceptStream() (net.Conn, error) {
	s, err := c.session.AcceptStream()
	if err != nil {
		return nil, err
	}

	c.markForUse()
	return &incomingStream{
		Stream: s,
		parent: c,
	}, nil
}

// incomingStream wraps yamux.Stream and releases the reference held on the
// parent Conn when closed
type incomingStream struct {
	*yamux.Stream

	parent *Conn
}

func (s *incomingStream) Close() error {
	err := s.Stream.Close()

	// always release parent even if error
	s.parent.releaseUse()

	return err
}

// ConnPool is used to maintain a connection pool to other
// Nomad servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
// connections in the RpcNomad mode. Raft connections
// are pooled separately.
type ConnPool struct {
	sync.Mutex

	// logger is the logger to be used
	logger *log.Logger

	// The maximum time to keep a connection open
	maxTime time.Duration

	// The maximum number of open streams to keep
	maxStreams int

	// pool maps an address to an open connection
	pool map[string]*Conn

	// limiter is used to throttle the number of connect attempts
	// to a given address. The first thread will attempt a connection
	// and put a channel in here, which all other threads will wait
	// on to close.
	limiter map[string]chan struct{}

	// TLS wrapper
	tlsWrap tlsutil.RegionWrapper

	// Used to indicate the pool is shutdown
	shutdown   bool
	shutdownCh chan struct{}

	// connListener is used to notify a potential listener of a new connection
	// being made.
	connListener chan<- *Conn
}

// NewPool is used to make a new connection pool. It maintains at most one
// connection per host, for up to maxTime. Set maxTime to 0 to disable
// reaping. maxStreams is used to control the number of idle streams allowed.
// If TLS settings are provided, outgoing connections use TLS.
func NewPool(logger hclog.Logger, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
	pool := &ConnPool{
		logger:     logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}),
		maxTime:    maxTime,
		maxStreams: maxStreams,
		pool:       make(map[string]*Conn),
		limiter:    make(map[string]chan struct{}),
		tlsWrap:    tlsWrap,
		shutdownCh: make(chan struct{}),
	}
	if maxTime > 0 {
		go pool.reap()
	}
	return pool
}

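// A minimal construction sketch (the duration and stream count are
// illustrative, not recommended values):
//
//	pool := NewPool(hclog.Default(), 5*time.Minute, 4, nil)
//	defer pool.Shutdown()
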
// Shutdown is used to close the connection pool
func (p *ConnPool) Shutdown() error {
	p.Lock()
	defer p.Unlock()

	for _, conn := range p.pool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)

	if p.shutdown {
		return nil
	}

	if p.connListener != nil {
		close(p.connListener)
		p.connListener = nil
	}

	p.shutdown = true
	close(p.shutdownCh)
	return nil
}

// ReloadTLS reloads TLS configuration on the fly
func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) {
	p.Lock()
	defer p.Unlock()

	oldPool := p.pool
	for _, conn := range oldPool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)
	p.tlsWrap = tlsWrap
}

// SetConnListener is used to listen to new connections being made. The
// channel will be closed when the conn pool is closed or a new listener is set.
func (p *ConnPool) SetConnListener(l chan<- *Conn) {
	p.Lock()
	defer p.Unlock()

	// Close the old listener
	if p.connListener != nil {
		close(p.connListener)
	}

	// Store the new listener
	p.connListener = l
}

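// A sketch of a listener draining the channel (buffering matters: acquire
// drops the notification if the channel would block, so size it generously):
//
//	conns := make(chan *Conn, 32)
//	pool.SetConnListener(conns)
//	go func() {
//		for conn := range conns {
//			_ = conn // e.g. accept multiplexed streams via conn.AcceptStream()
//		}
//	}()
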
// acquire is used to get a pooled connection, or to create a new one
func (p *ConnPool) acquire(region string, addr net.Addr) (*Conn, error) {
	// Check to see if there's a pooled connection available. This is up
	// here since it should be the vastly more common case than the rest
	// of the code here.
	p.Lock()
	c := p.pool[addr.String()]
	if c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	// If not (while we are still locked), set up the throttling structure
	// for this address, which will make everyone else wait until our
	// attempt is done.
	var wait chan struct{}
	var ok bool
	if wait, ok = p.limiter[addr.String()]; !ok {
		wait = make(chan struct{})
		p.limiter[addr.String()] = wait
	}
	isLeadThread := !ok
	p.Unlock()

	// If we are the lead thread, make the new connection and then wake
	// everybody else up to see if we got it.
	if isLeadThread {
		c, err := p.getNewConn(region, addr)
		p.Lock()
		delete(p.limiter, addr.String())
		close(wait)
		if err != nil {
			p.Unlock()
			return nil, err
		}

		p.pool[addr.String()] = c

		// If there is a connection listener, notify them of the new connection.
		if p.connListener != nil {
			select {
			case p.connListener <- c:
			default:
			}
		}

		p.Unlock()
		return c, nil
	}

	// Otherwise, wait for the lead thread to attempt the connection
	// and use what's in the pool at that point.
	select {
	case <-p.shutdownCh:
		return nil, fmt.Errorf("rpc error: shutdown")
	case <-wait:
	}

	// See if the lead thread was able to get us a connection.
	p.Lock()
	if c := p.pool[addr.String()]; c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	p.Unlock()
	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}

// getNewConn is used to open and return a new connection
func (p *ConnPool) getNewConn(region string, addr net.Addr) (*Conn, error) {
	// Try to dial the conn
	conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
	if err != nil {
		return nil, err
	}

	// Cast to TCPConn
	if tcp, ok := conn.(*net.TCPConn); ok {
		tcp.SetKeepAlive(true)
		tcp.SetNoDelay(true)
	}

	// Check if TLS is enabled
	if p.tlsWrap != nil {
		// Switch the connection into TLS mode
		if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
			conn.Close()
			return nil, err
		}

		// Wrap the connection in a TLS client
		tlsConn, err := p.tlsWrap(region, conn)
		if err != nil {
			conn.Close()
			return nil, err
		}
		conn = tlsConn
	}

	// Write the multiplex byte to set the mode
	if _, err := conn.Write([]byte{byte(RpcMultiplexV2)}); err != nil {
		conn.Close()
		return nil, err
	}

	// Set up the yamux config, routing its logs to the pool's logger
	conf := yamux.DefaultConfig()
	conf.LogOutput = nil
	conf.Logger = p.logger

	// Create a multiplexed session
	session, err := yamux.Client(conn, conf)
	if err != nil {
		conn.Close()
		return nil, err
	}

	// Wrap the connection
	c := &Conn{
		refCount: 1,
		addr:     addr,
		session:  session,
		clients:  list.New(),
		lastUsed: atomic.Pointer[time.Time]{},
		pool:     p,
	}

	now := time.Now()
	c.lastUsed.Store(&now)
	return c, nil
}

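// The resulting handshake, as the server sees it (the mode bytes such as
// RpcTLS and RpcMultiplexV2 are constants defined elsewhere in this package):
//
//	TCP connect
//	  -> [RpcTLS] byte, then TLS handshake   (only when tlsWrap is configured)
//	  -> [RpcMultiplexV2] byte
//	  -> yamux session; each stream then sends its own mode byte
//	     (e.g. RpcNomad or RpcStreaming) before carrying traffic
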
// clearConn is used to clear any cached connection, potentially in response to
// an error
func (p *ConnPool) clearConn(conn *Conn) {
	// Ensure returned streams are closed
	atomic.StoreInt32(&conn.shouldClose, 1)

	// Clear from the cache
	p.Lock()
	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
		delete(p.pool, conn.addr.String())
	}
	p.Unlock()

	// Close down immediately if idle
	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
		conn.Close()
	}
}

// getRPCClient is used to get a usable client for an address
func (p *ConnPool) getRPCClient(region string, addr net.Addr) (*Conn, *StreamClient, error) {
	retries := 0
START:
	// Try to get a conn first
	conn, err := p.acquire(region, addr)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
	}

	// Get a client
	client, err := conn.getRPCClient()
	if err != nil {
		p.clearConn(conn)
		conn.releaseUse()

		// Try to redial, possible that the TCP session closed due to timeout
		if retries == 0 {
			retries++
			goto START
		}
		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
	}
	return conn, client, nil
}

// StreamingRPC is used to make a streaming RPC call. Callers must
// close the connection when done.
func (p *ConnPool) StreamingRPC(region string, addr net.Addr) (net.Conn, error) {
	conn, err := p.acquire(region, addr)
	if err != nil {
		return nil, fmt.Errorf("failed to get conn: %v", err)
	}

	s, err := conn.session.OpenStream()
	if err != nil {
		conn.releaseUse()
		return nil, fmt.Errorf("failed to open a streaming connection: %v", err)
	}

	if _, err := s.Write([]byte{byte(RpcStreaming)}); err != nil {
		conn.releaseUse()
		conn.Close()
		return nil, err
	}

	// Wrap the stream so that closing it releases the reference taken by
	// acquire above; otherwise the conn would never be considered idle and
	// could never be reaped.
	return &incomingStream{
		Stream: s,
		parent: conn,
	}, nil
}

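// A minimal caller sketch ("global" and serverAddr are illustrative):
//
//	stream, err := pool.StreamingRPC("global", serverAddr)
//	if err == nil {
//		defer stream.Close()
//		// read/write the raw stream, e.g. with NewClientCodec(stream)
//	}
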
// RPC is used to make an RPC call to a remote host
func (p *ConnPool) RPC(region string, addr net.Addr, method string, args interface{}, reply interface{}) error {
	// Get a usable client
	conn, sc, err := p.getRPCClient(region, addr)
	if err != nil {
		return fmt.Errorf("rpc error: %w", err)
	}
	defer conn.releaseUse()

	// Make the RPC call
	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
	if err != nil {
		sc.Close()

		// If we read EOF, the session is toast. Clear it and open a
		// new session next time
		// See https://github.com/hashicorp/consul/blob/v1.6.3/agent/pool/pool.go#L471-L477
		if helper.IsErrEOF(err) {
			p.clearConn(conn)
		}

		// If the error is an RPC Coded error
		// return the coded error without wrapping
		if structs.IsErrRPCCoded(err) {
			return err
		}

		// TODO wrap with RPCCoded error instead
		return fmt.Errorf("rpc error: %w", err)
	}

	// Done with the connection
	conn.returnClient(sc)
	return nil
}

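// A minimal caller sketch; "Example.Method", ExampleArgs, and ExampleReply are
// placeholders for a real registered net/rpc endpoint and its types:
//
//	var reply ExampleReply
//	err := pool.RPC("global", serverAddr, "Example.Method", &ExampleArgs{}, &reply)
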
// reap is used to close conns open over maxTime
func (p *ConnPool) reap() {
	for {
		// Sleep for a while
		select {
		case <-p.shutdownCh:
			return
		case <-time.After(time.Second):
		}

		// Reap all old conns
		p.Lock()
		var removed []string
		now := time.Now()
		for host, conn := range p.pool {
			// Skip recently used connections
			if now.Sub(*conn.lastUsed.Load()) < p.maxTime {
				continue
			}

			// Skip connections with active streams
			if atomic.LoadInt32(&conn.refCount) > 0 {
				continue
			}

			// Close the conn
			conn.Close()

			// Remove from pool
			removed = append(removed, host)
		}
		for _, host := range removed {
			delete(p.pool, host)
		}
		p.Unlock()
	}
}