github.com/hernad/nomad@v1.6.112/helper/pool/pool.go

// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: MPL-2.0

package pool

import (
	"container/list"
	"fmt"
	"io"
	"log"
	"net"
	"net/rpc"
	"sync"
	"sync/atomic"
	"time"

	hclog "github.com/hashicorp/go-hclog"
	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
	"github.com/hashicorp/yamux"
	"github.com/hernad/nomad/helper"
	"github.com/hernad/nomad/helper/tlsutil"
	"github.com/hernad/nomad/nomad/structs"
)

// NewClientCodec returns a new rpc.ClientCodec to be used to make RPC calls.
func NewClientCodec(conn io.ReadWriteCloser) rpc.ClientCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
}

// NewServerCodec returns a new rpc.ServerCodec to be used to handle RPCs.
func NewServerCodec(conn io.ReadWriteCloser) rpc.ServerCodec {
	return msgpackrpc.NewCodecFromHandle(true, true, conn, structs.MsgpackHandle)
}
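
// As an illustration, a server accept loop might pair NewServerCodec with the
// standard library's net/rpc server (a sketch; the listener ln and the
// MyService handler are assumptions, not part of this file):
//
//	srv := rpc.NewServer()
//	_ = srv.Register(&MyService{})
//	for {
//		conn, err := ln.Accept()
//		if err != nil {
//			return
//		}
//		go srv.ServeCodec(NewServerCodec(conn))
//	}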

// StreamClient is used to wrap a stream with an RPC client
type StreamClient struct {
	stream net.Conn
	codec  rpc.ClientCodec
}

func (sc *StreamClient) Close() {
	sc.stream.Close()
	sc.codec.Close()
}

// Conn is a pooled connection to a Nomad server
type Conn struct {
	refCount    int32
	shouldClose int32

	addr     net.Addr
	session  *yamux.Session
	lastUsed atomic.Pointer[time.Time]

	pool *ConnPool

	clients    *list.List
	clientLock sync.Mutex
}

// markForUse does all the bookkeeping required to ready a connection for use,
// and ensures that active connections don't get reaped.
func (c *Conn) markForUse() {
	now := time.Now()
	c.lastUsed.Store(&now)
	atomic.AddInt32(&c.refCount, 1)
}

// releaseUse is the complement of markForUse, freeing up the reference count
func (c *Conn) releaseUse() {
	refCount := atomic.AddInt32(&c.refCount, -1)
	if refCount == 0 && atomic.LoadInt32(&c.shouldClose) == 1 {
		c.Close()
	}
}
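
// The typical pairing, as used by ConnPool.RPC below, is:
//
//	c.markForUse()   // taken inside ConnPool.acquire
//	defer c.releaseUse()
//	// ... issue RPCs over the connection ...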

func (c *Conn) Close() error {
	return c.session.Close()
}

// getRPCClient is used to get a cached or new client
func (c *Conn) getRPCClient() (*StreamClient, error) {
	// Check for cached client
	c.clientLock.Lock()
	front := c.clients.Front()
	if front != nil {
		c.clients.Remove(front)
	}
	c.clientLock.Unlock()
	if front != nil {
		return front.Value.(*StreamClient), nil
	}

	// Open a new stream on the session
	stream, err := c.session.Open()
	if err != nil {
		return nil, err
	}

	if _, err := stream.Write([]byte{byte(RpcNomad)}); err != nil {
		stream.Close()
		return nil, err
	}

	// Create a client codec
	codec := NewClientCodec(stream)

	// Return a new stream client
	sc := &StreamClient{
		stream: stream,
		codec:  codec,
	}
	return sc, nil
}

// returnClient is used when done with a stream
// to allow re-use by a future RPC
func (c *Conn) returnClient(client *StreamClient) {
	didSave := false
	c.clientLock.Lock()
	if c.clients.Len() < c.pool.maxStreams && atomic.LoadInt32(&c.shouldClose) == 0 {
		c.clients.PushFront(client)
		didSave = true

		// If this is a Yamux stream, shrink the internal buffers so that
		// we can GC the idle memory
		if ys, ok := client.stream.(*yamux.Stream); ok {
			ys.Shrink()
		}
	}
	c.clientLock.Unlock()
	if !didSave {
		client.Close()
	}
}

func (c *Conn) IsClosed() bool {
	return c.session.IsClosed()
}

func (c *Conn) AcceptStream() (net.Conn, error) {
	s, err := c.session.AcceptStream()
	if err != nil {
		return nil, err
	}

	c.markForUse()
	return &incomingStream{
		Stream: s,
		parent: c,
	}, nil
}

// incomingStream wraps yamux.Stream and releases the reference it holds on
// the parent Conn when closed
type incomingStream struct {
	*yamux.Stream

	parent *Conn
}

func (s *incomingStream) Close() error {
	err := s.Stream.Close()

	// always release the parent reference, even on error
	s.parent.releaseUse()

	return err
}

// ConnPool is used to maintain a connection pool to other
// Nomad servers. This is used to reduce the latency of
// RPC requests between servers. It is only used to pool
// connections in the rpcNomad mode. Raft connections
// are pooled separately.
type ConnPool struct {
	sync.Mutex

	// logger is the logger to be used
	logger *log.Logger

	// The maximum time to keep a connection open
	maxTime time.Duration

	// The maximum number of open streams to keep
	maxStreams int

	// pool maps an address to an open connection
	pool map[string]*Conn

	// limiter is used to throttle the number of connect attempts
	// to a given address. The first thread will attempt a connection
	// and put a channel in here, which all other threads will wait
	// on to close.
	limiter map[string]chan struct{}

	// TLS wrapper
	tlsWrap tlsutil.RegionWrapper

	// Used to indicate the pool is shut down
	shutdown   bool
	shutdownCh chan struct{}

	// connListener is used to notify a potential listener of a new connection
	// being made.
	connListener chan<- *Conn
}

// NewPool is used to make a new connection pool.
// It maintains at most one connection per host, for up to maxTime.
// Set maxTime to 0 to disable reaping. maxStreams is used to control
// the number of idle streams allowed.
// If TLS settings are provided, outgoing connections use TLS.
func NewPool(logger hclog.Logger, maxTime time.Duration, maxStreams int, tlsWrap tlsutil.RegionWrapper) *ConnPool {
	pool := &ConnPool{
		logger:     logger.StandardLogger(&hclog.StandardLoggerOptions{InferLevels: true}),
		maxTime:    maxTime,
		maxStreams: maxStreams,
		pool:       make(map[string]*Conn),
		limiter:    make(map[string]chan struct{}),
		tlsWrap:    tlsWrap,
		shutdownCh: make(chan struct{}),
	}
	if maxTime > 0 {
		go pool.reap()
	}
	return pool
}
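
// A minimal usage sketch (illustrative only; the address addr and the
// "Status.Ping" method name are assumptions, not defined in this file):
//
//	pool := NewPool(hclog.Default(), 5*time.Minute, 4, nil)
//	defer pool.Shutdown()
//
//	var reply struct{}
//	if err := pool.RPC("global", addr, "Status.Ping", struct{}{}, &reply); err != nil {
//		// handle RPC failure
//	}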

// Shutdown is used to close the connection pool
func (p *ConnPool) Shutdown() error {
	p.Lock()
	defer p.Unlock()

	for _, conn := range p.pool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)

	if p.shutdown {
		return nil
	}

	if p.connListener != nil {
		close(p.connListener)
		p.connListener = nil
	}

	p.shutdown = true
	close(p.shutdownCh)
	return nil
}

// ReloadTLS reloads TLS configuration on the fly
func (p *ConnPool) ReloadTLS(tlsWrap tlsutil.RegionWrapper) {
	p.Lock()
	defer p.Unlock()

	oldPool := p.pool
	for _, conn := range oldPool {
		conn.Close()
	}
	p.pool = make(map[string]*Conn)
	p.tlsWrap = tlsWrap
}

// SetConnListener is used to listen to new connections being made. The
// channel will be closed when the conn pool is closed or a new listener is set.
func (p *ConnPool) SetConnListener(l chan<- *Conn) {
	p.Lock()
	defer p.Unlock()

	// Close the old listener
	if p.connListener != nil {
		close(p.connListener)
	}

	// Store the new listener
	p.connListener = l
}
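
// For example, a consumer might watch for new server connections like this
// (a sketch; what the goroutine does with each Conn is an assumption). The
// range loop terminates because the pool closes the channel on Shutdown or
// when a new listener replaces it:
//
//	conns := make(chan *Conn, 4)
//	pool.SetConnListener(conns)
//	go func() {
//		for c := range conns {
//			// inspect or track c
//			_ = c
//		}
//	}()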

// acquire is used to return a pooled connection, or to create a new one
func (p *ConnPool) acquire(region string, addr net.Addr) (*Conn, error) {
	// Check to see if there's a pooled connection available. This is up
	// here since it should be the vastly more common case than the rest
	// of the code here.
	p.Lock()
	c := p.pool[addr.String()]
	if c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	// If not (while we are still locked), set up the throttling structure
	// for this address, which will make everyone else wait until our
	// attempt is done.
	var wait chan struct{}
	var ok bool
	if wait, ok = p.limiter[addr.String()]; !ok {
		wait = make(chan struct{})
		p.limiter[addr.String()] = wait
	}
	isLeadThread := !ok
	p.Unlock()

	// If we are the lead thread, make the new connection and then wake
	// everybody else up to see if we got it.
	if isLeadThread {
		c, err := p.getNewConn(region, addr)
		p.Lock()
		delete(p.limiter, addr.String())
		close(wait)
		if err != nil {
			p.Unlock()
			return nil, err
		}

		p.pool[addr.String()] = c

		// If there is a connection listener, notify them of the new connection.
		if p.connListener != nil {
			select {
			case p.connListener <- c:
			default:
			}
		}

		p.Unlock()
		return c, nil
	}

	// Otherwise, wait for the lead thread to attempt the connection
	// and use what's in the pool at that point.
	select {
	case <-p.shutdownCh:
		return nil, fmt.Errorf("rpc error: shutdown")
	case <-wait:
	}

	// See if the lead thread was able to get us a connection.
	p.Lock()
	if c := p.pool[addr.String()]; c != nil {
		c.markForUse()
		p.Unlock()
		return c, nil
	}

	p.Unlock()
	return nil, fmt.Errorf("rpc error: lead thread didn't get connection")
}

// getNewConn is used to return a new connection
func (p *ConnPool) getNewConn(region string, addr net.Addr) (*Conn, error) {
	// Try to dial the conn
	conn, err := net.DialTimeout("tcp", addr.String(), 10*time.Second)
	if err != nil {
		return nil, err
	}

	// Cast to TCPConn
	if tcp, ok := conn.(*net.TCPConn); ok {
		tcp.SetKeepAlive(true)
		tcp.SetNoDelay(true)
	}

	// Check if TLS is enabled
	if p.tlsWrap != nil {
		// Switch the connection into TLS mode
		if _, err := conn.Write([]byte{byte(RpcTLS)}); err != nil {
			conn.Close()
			return nil, err
		}

		// Wrap the connection in a TLS client
		tlsConn, err := p.tlsWrap(region, conn)
		if err != nil {
			conn.Close()
			return nil, err
		}
		conn = tlsConn
	}

	// Write the multiplex byte to set the mode
	if _, err := conn.Write([]byte{byte(RpcMultiplexV2)}); err != nil {
		conn.Close()
		return nil, err
	}

	// Set up the yamux config, routing its logs through the pool's logger
	conf := yamux.DefaultConfig()
	conf.LogOutput = nil
	conf.Logger = p.logger

	// Create a multiplexed session
	session, err := yamux.Client(conn, conf)
	if err != nil {
		conn.Close()
		return nil, err
	}

	// Wrap the connection
	c := &Conn{
		refCount: 1,
		addr:     addr,
		session:  session,
		clients:  list.New(),
		lastUsed: atomic.Pointer[time.Time]{},
		pool:     p,
	}

	now := time.Now()
	c.lastUsed.Store(&now)
	return c, nil
}
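
// The wire handshake performed by getNewConn, as implemented above:
//
//	client -> server:  [RpcTLS]          (only when tlsWrap is configured)
//	client <-> server: TLS handshake     (via the region-aware tlsWrap)
//	client -> server:  [RpcMultiplexV2]  (switch the connection to yamux)
//
// Each stream subsequently opened on the session then starts with its own
// mode byte: RpcNomad for pooled RPC clients, RpcStreaming for StreamingRPC.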

// clearConn is used to clear any cached connection, potentially in response to
// an error
func (p *ConnPool) clearConn(conn *Conn) {
	// Ensure returned streams are closed
	atomic.StoreInt32(&conn.shouldClose, 1)

	// Clear from the cache
	p.Lock()
	if c, ok := p.pool[conn.addr.String()]; ok && c == conn {
		delete(p.pool, conn.addr.String())
	}
	p.Unlock()

	// Close down immediately if idle
	if refCount := atomic.LoadInt32(&conn.refCount); refCount == 0 {
		conn.Close()
	}
}

// getRPCClient is used to get a usable client for an address
func (p *ConnPool) getRPCClient(region string, addr net.Addr) (*Conn, *StreamClient, error) {
	retries := 0
START:
	// Try to get a conn first
	conn, err := p.acquire(region, addr)
	if err != nil {
		return nil, nil, fmt.Errorf("failed to get conn: %v", err)
	}

	// Get a client
	client, err := conn.getRPCClient()
	if err != nil {
		p.clearConn(conn)
		conn.releaseUse()

		// Try to redial; it's possible the TCP session closed due to timeout
		if retries == 0 {
			retries++
			goto START
		}
		return nil, nil, fmt.Errorf("failed to start stream: %v", err)
	}
	return conn, client, nil
}

// StreamingRPC is used to make a streaming RPC call. Callers must
// close the connection when done.
func (p *ConnPool) StreamingRPC(region string, addr net.Addr) (net.Conn, error) {
	conn, err := p.acquire(region, addr)
	if err != nil {
		return nil, fmt.Errorf("failed to get conn: %v", err)
	}

	s, err := conn.session.Open()
	if err != nil {
		return nil, fmt.Errorf("failed to open a streaming connection: %v", err)
	}

	if _, err := s.Write([]byte{byte(RpcStreaming)}); err != nil {
		conn.Close()
		return nil, err
	}

	return s, nil
}
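
// Callers own the returned stream; a sketch of typical use (addr is an
// assumption for illustration):
//
//	s, err := pool.StreamingRPC("global", addr)
//	if err != nil {
//		return err
//	}
//	defer s.Close()
//	// ... frame the streaming request/response protocol over s ...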

// RPC is used to make an RPC call to a remote host
func (p *ConnPool) RPC(region string, addr net.Addr, method string, args interface{}, reply interface{}) error {
	// Get a usable client
	conn, sc, err := p.getRPCClient(region, addr)
	if err != nil {
		return fmt.Errorf("rpc error: %w", err)
	}
	defer conn.releaseUse()

	// Make the RPC call
	err = msgpackrpc.CallWithCodec(sc.codec, method, args, reply)
	if err != nil {
		sc.Close()

		// If we read EOF, the session is toast. Clear it and open a
		// new session next time
		// See https://github.com/hashicorp/consul/blob/v1.6.3/agent/pool/pool.go#L471-L477
		if helper.IsErrEOF(err) {
			p.clearConn(conn)
		}

		// If the error is an RPC Coded error
		// return the coded error without wrapping
		if structs.IsErrRPCCoded(err) {
			return err
		}

		// TODO wrap with RPCCoded error instead
		return fmt.Errorf("rpc error: %w", err)
	}

	// Done with the connection
	conn.returnClient(sc)
	return nil
}

// reap is used to close conns that have been idle for longer than maxTime
// and have no active streams
func (p *ConnPool) reap() {
	for {
		// Sleep for a while
		select {
		case <-p.shutdownCh:
			return
		case <-time.After(time.Second):
		}

		// Reap all old conns
		p.Lock()
		var removed []string
		now := time.Now()
		for host, conn := range p.pool {
			// Skip recently used connections
			if now.Sub(*conn.lastUsed.Load()) < p.maxTime {
				continue
			}

			// Skip connections with active streams
			if atomic.LoadInt32(&conn.refCount) > 0 {
				continue
			}

			// Close the conn
			conn.Close()

			// Remove from pool
			removed = append(removed, host)
		}
		for _, host := range removed {
			delete(p.pool, host)
		}
		p.Unlock()
	}
}