github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/go.mongodb.org/mongo-driver/x/mongo/driver/topology/server.go (about)

     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at http://www.apache.org/licenses/LICENSE-2.0
     6  
     7  package topology
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"fmt"
    13  	"net"
    14  	"sync"
    15  	"sync/atomic"
    16  	"time"
    17  
    18  	"go.mongodb.org/mongo-driver/bson/primitive"
    19  	"go.mongodb.org/mongo-driver/event"
    20  	"go.mongodb.org/mongo-driver/mongo/address"
    21  	"go.mongodb.org/mongo-driver/mongo/description"
    22  	"go.mongodb.org/mongo-driver/x/mongo/driver"
    23  	"go.mongodb.org/mongo-driver/x/mongo/driver/operation"
    24  )
    25  
    26  const minHeartbeatInterval = 500 * time.Millisecond
    27  const wireVersion42 = 8 // Wire version for MongoDB 4.2
    28  
    29  // Server state constants.
    30  const (
    31  	serverDisconnected int64 = iota
    32  	serverDisconnecting
    33  	serverConnected
    34  )
    35  
    36  func serverStateString(state int64) string {
    37  	switch state {
    38  	case serverDisconnected:
    39  		return "Disconnected"
    40  	case serverDisconnecting:
    41  		return "Disconnecting"
    42  	case serverConnected:
    43  		return "Connected"
    44  	}
    45  
    46  	return ""
    47  }
    48  
    49  var (
    50  	// ErrServerClosed occurs when an attempt to Get a connection is made after
    51  	// the server has been closed.
    52  	ErrServerClosed = errors.New("server is closed")
    53  	// ErrServerConnected occurs when at attempt to Connect is made after a server
    54  	// has already been connected.
    55  	ErrServerConnected = errors.New("server is connected")
    56  
    57  	errCheckCancelled = errors.New("server check cancelled")
    58  	emptyDescription  = description.NewDefaultServer("")
    59  )
    60  
    61  // SelectedServer represents a specific server that was selected during server selection.
    62  // It contains the kind of the topology it was selected from.
    63  type SelectedServer struct {
    64  	*Server
    65  
    66  	Kind description.TopologyKind
    67  }
    68  
    69  // Description returns a description of the server as of the last heartbeat.
    70  func (ss *SelectedServer) Description() description.SelectedServer {
    71  	sdesc := ss.Server.Description()
    72  	return description.SelectedServer{
    73  		Server: sdesc,
    74  		Kind:   ss.Kind,
    75  	}
    76  }
    77  
    78  // Server is a single server within a topology.
    79  type Server struct {
    80  	// The following integer fields must be accessed using the atomic package and should be at the
    81  	// beginning of the struct.
    82  	// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
    83  	// - suggested layout: https://go101.org/article/memory-layout.html
    84  
    85  	state          int64
    86  	operationCount int64
    87  
    88  	cfg     *serverConfig
    89  	address address.Address
    90  
    91  	// connection related fields
    92  	pool *pool
    93  
    94  	// goroutine management fields
    95  	done          chan struct{}
    96  	checkNow      chan struct{}
    97  	disconnecting chan struct{}
    98  	closewg       sync.WaitGroup
    99  
   100  	// description related fields
   101  	desc                   atomic.Value // holds a description.Server
   102  	updateTopologyCallback atomic.Value
   103  	topologyID             primitive.ObjectID
   104  
   105  	// subscriber related fields
   106  	subLock             sync.Mutex
   107  	subscribers         map[uint64]chan description.Server
   108  	currentSubscriberID uint64
   109  	subscriptionsClosed bool
   110  
   111  	// heartbeat and cancellation related fields
   112  	// globalCtx should be created in NewServer and cancelled in Disconnect to signal that the server is shutting down.
   113  	// heartbeatCtx should be used for individual heartbeats and should be a child of globalCtx so that it will be
   114  	// cancelled automatically during shutdown.
   115  	heartbeatLock      sync.Mutex
   116  	conn               *connection
   117  	globalCtx          context.Context
   118  	globalCtxCancel    context.CancelFunc
   119  	heartbeatCtx       context.Context
   120  	heartbeatCtxCancel context.CancelFunc
   121  
   122  	processErrorLock sync.Mutex
   123  	rttMonitor       *rttMonitor
   124  }
   125  
   126  // updateTopologyCallback is a callback used to create a server that should be called when the parent Topology instance
   127  // should be updated based on a new server description. The callback must return the server description that should be
   128  // stored by the server.
   129  type updateTopologyCallback func(description.Server) description.Server
   130  
   131  // ConnectServer creates a new Server and then initializes it using the
   132  // Connect method.
   133  func ConnectServer(addr address.Address, updateCallback updateTopologyCallback, topologyID primitive.ObjectID, opts ...ServerOption) (*Server, error) {
   134  	srvr := NewServer(addr, topologyID, opts...)
   135  	err := srvr.Connect(updateCallback)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	return srvr, nil
   140  }
   141  
   142  // NewServer creates a new server. The mongodb server at the address will be monitored
   143  // on an internal monitoring goroutine.
   144  func NewServer(addr address.Address, topologyID primitive.ObjectID, opts ...ServerOption) *Server {
   145  	cfg := newServerConfig(opts...)
   146  	globalCtx, globalCtxCancel := context.WithCancel(context.Background())
   147  	s := &Server{
   148  		state: serverDisconnected,
   149  
   150  		cfg:     cfg,
   151  		address: addr,
   152  
   153  		done:          make(chan struct{}),
   154  		checkNow:      make(chan struct{}, 1),
   155  		disconnecting: make(chan struct{}),
   156  
   157  		topologyID: topologyID,
   158  
   159  		subscribers:     make(map[uint64]chan description.Server),
   160  		globalCtx:       globalCtx,
   161  		globalCtxCancel: globalCtxCancel,
   162  	}
   163  	s.desc.Store(description.NewDefaultServer(addr))
   164  	rttCfg := &rttConfig{
   165  		interval:           cfg.heartbeatInterval,
   166  		minRTTWindow:       5 * time.Minute,
   167  		createConnectionFn: s.createConnection,
   168  		createOperationFn:  s.createBaseOperation,
   169  	}
   170  	s.rttMonitor = newRTTMonitor(rttCfg)
   171  
   172  	pc := poolConfig{
   173  		Address:          addr,
   174  		MinPoolSize:      cfg.minConns,
   175  		MaxPoolSize:      cfg.maxConns,
   176  		MaxConnecting:    cfg.maxConnecting,
   177  		MaxIdleTime:      cfg.poolMaxIdleTime,
   178  		MaintainInterval: cfg.poolMaintainInterval,
   179  		PoolMonitor:      cfg.poolMonitor,
   180  		Logger:           cfg.logger,
   181  		handshakeErrFn:   s.ProcessHandshakeError,
   182  	}
   183  
   184  	connectionOpts := copyConnectionOpts(cfg.connectionOpts)
   185  	s.pool = newPool(pc, connectionOpts...)
   186  	s.publishServerOpeningEvent(s.address)
   187  
   188  	return s
   189  }
   190  
   191  // Connect initializes the Server by starting background monitoring goroutines.
   192  // This method must be called before a Server can be used.
   193  func (s *Server) Connect(updateCallback updateTopologyCallback) error {
   194  	if !atomic.CompareAndSwapInt64(&s.state, serverDisconnected, serverConnected) {
   195  		return ErrServerConnected
   196  	}
   197  
   198  	desc := description.NewDefaultServer(s.address)
   199  	if s.cfg.loadBalanced {
   200  		// LBs automatically start off with kind LoadBalancer because there is no monitoring routine for state changes.
   201  		desc.Kind = description.LoadBalancer
   202  	}
   203  	s.desc.Store(desc)
   204  	s.updateTopologyCallback.Store(updateCallback)
   205  
   206  	if !s.cfg.monitoringDisabled && !s.cfg.loadBalanced {
   207  		s.rttMonitor.connect()
   208  		s.closewg.Add(1)
   209  		go s.update()
   210  	}
   211  
   212  	// The CMAP spec describes that pools should only be marked "ready" when the server description
   213  	// is updated to something other than "Unknown". However, we maintain the previous Server
   214  	// behavior here and immediately mark the pool as ready during Connect() to simplify and speed
   215  	// up the Client startup behavior. The risk of marking a pool as ready proactively during
   216  	// Connect() is that we could attempt to create connections to a server that was configured
   217  	// erroneously until the first server check or checkOut() failure occurs, when the SDAM error
   218  	// handler would transition the Server back to "Unknown" and set the pool to "paused".
   219  	return s.pool.ready()
   220  }
   221  
   222  // Disconnect closes sockets to the server referenced by this Server.
   223  // Subscriptions to this Server will be closed. Disconnect will shutdown
   224  // any monitoring goroutines, closeConnection the idle connection pool, and will
   225  // wait until all the in use connections have been returned to the connection
   226  // pool and are closed before returning. If the context expires via
   227  // cancellation, deadline, or timeout before the in use connections have been
   228  // returned, the in use connections will be closed, resulting in the failure of
   229  // any in flight read or write operations. If this method returns with no
   230  // errors, all connections associated with this Server have been closed.
   231  func (s *Server) Disconnect(ctx context.Context) error {
   232  	if !atomic.CompareAndSwapInt64(&s.state, serverConnected, serverDisconnecting) {
   233  		return ErrServerClosed
   234  	}
   235  
   236  	s.updateTopologyCallback.Store((updateTopologyCallback)(nil))
   237  
   238  	// Cancel the global context so any new contexts created from it will be automatically cancelled. Close the done
   239  	// channel so the update() routine will know that it can stop. Cancel any in-progress monitoring checks at the end.
   240  	// The done channel is closed before cancelling the check so the update routine() will immediately detect that it
   241  	// can stop rather than trying to create new connections until the read from done succeeds.
   242  	s.globalCtxCancel()
   243  	close(s.done)
   244  	s.cancelCheck()
   245  
   246  	s.rttMonitor.disconnect()
   247  	s.pool.close(ctx)
   248  
   249  	s.closewg.Wait()
   250  	atomic.StoreInt64(&s.state, serverDisconnected)
   251  
   252  	return nil
   253  }
   254  
   255  // Connection gets a connection to the server.
   256  func (s *Server) Connection(ctx context.Context) (driver.Connection, error) {
   257  	if atomic.LoadInt64(&s.state) != serverConnected {
   258  		return nil, ErrServerClosed
   259  	}
   260  
   261  	// Increment the operation count before calling checkOut to make sure that all connection
   262  	// requests are included in the operation count, including those in the wait queue. If we got an
   263  	// error instead of a connection, immediately decrement the operation count.
   264  	atomic.AddInt64(&s.operationCount, 1)
   265  	conn, err := s.pool.checkOut(ctx)
   266  	if err != nil {
   267  		atomic.AddInt64(&s.operationCount, -1)
   268  		return nil, err
   269  	}
   270  
   271  	return &Connection{
   272  		connection: conn,
   273  		cleanupServerFn: func() {
   274  			// Decrement the operation count whenever the caller is done with the connection. Note
   275  			// that cleanupServerFn() is not called while the connection is pinned to a cursor or
   276  			// transaction, so the operation count is not decremented until the cursor is closed or
   277  			// the transaction is committed or aborted. Use an int64 instead of a uint64 to mitigate
   278  			// the impact of any possible bugs that could cause the uint64 to underflow, which would
   279  			// make the server much less selectable.
   280  			atomic.AddInt64(&s.operationCount, -1)
   281  		},
   282  	}, nil
   283  }
   284  
   285  // ProcessHandshakeError implements SDAM error handling for errors that occur before a connection
   286  // finishes handshaking.
   287  func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint64, serviceID *primitive.ObjectID) {
   288  	// Ignore the error if the server is behind a load balancer but the service ID is unknown. This indicates that the
   289  	// error happened when dialing the connection or during the MongoDB handshake, so we don't know the service ID to
   290  	// use for clearing the pool.
   291  	if err == nil || s.cfg.loadBalanced && serviceID == nil {
   292  		return
   293  	}
   294  	// Ignore the error if the connection is stale.
   295  	if startingGenerationNumber < s.pool.generation.getGeneration(serviceID) {
   296  		return
   297  	}
   298  
   299  	// Unwrap any connection errors. If there is no wrapped connection error, then the error should
   300  	// not result in any Server state change (e.g. a command error from the database).
   301  	wrappedConnErr := unwrapConnectionError(err)
   302  	if wrappedConnErr == nil {
   303  		return
   304  	}
   305  
   306  	// Must hold the processErrorLock while updating the server description and clearing the pool.
   307  	// Not holding the lock leads to possible out-of-order processing of pool.clear() and
   308  	// pool.ready() calls from concurrent server description updates.
   309  	s.processErrorLock.Lock()
   310  	defer s.processErrorLock.Unlock()
   311  
   312  	// Since the only kind of ConnectionError we receive from pool.Get will be an initialization error, we should set
   313  	// the description.Server appropriately. The description should not have a TopologyVersion because the staleness
   314  	// checking logic above has already determined that this description is not stale.
   315  	s.updateDescription(description.NewServerFromError(s.address, wrappedConnErr, nil))
   316  	s.pool.clear(err, serviceID)
   317  	s.cancelCheck()
   318  }
   319  
   320  // Description returns a description of the server as of the last heartbeat.
   321  func (s *Server) Description() description.Server {
   322  	return s.desc.Load().(description.Server)
   323  }
   324  
   325  // SelectedDescription returns a description.SelectedServer with a Kind of
   326  // Single. This can be used when performing tasks like monitoring a batch
   327  // of servers and you want to run one off commands against those servers.
   328  func (s *Server) SelectedDescription() description.SelectedServer {
   329  	sdesc := s.Description()
   330  	return description.SelectedServer{
   331  		Server: sdesc,
   332  		Kind:   description.Single,
   333  	}
   334  }
   335  
   336  // Subscribe returns a ServerSubscription which has a channel on which all
   337  // updated server descriptions will be sent. The channel will have a buffer
   338  // size of one, and will be pre-populated with the current description.
   339  func (s *Server) Subscribe() (*ServerSubscription, error) {
   340  	if atomic.LoadInt64(&s.state) != serverConnected {
   341  		return nil, ErrSubscribeAfterClosed
   342  	}
   343  	ch := make(chan description.Server, 1)
   344  	ch <- s.desc.Load().(description.Server)
   345  
   346  	s.subLock.Lock()
   347  	defer s.subLock.Unlock()
   348  	if s.subscriptionsClosed {
   349  		return nil, ErrSubscribeAfterClosed
   350  	}
   351  	id := s.currentSubscriberID
   352  	s.subscribers[id] = ch
   353  	s.currentSubscriberID++
   354  
   355  	ss := &ServerSubscription{
   356  		C:  ch,
   357  		s:  s,
   358  		id: id,
   359  	}
   360  
   361  	return ss, nil
   362  }
   363  
   364  // RequestImmediateCheck will cause the server to send a heartbeat immediately
   365  // instead of waiting for the heartbeat timeout.
   366  func (s *Server) RequestImmediateCheck() {
   367  	select {
   368  	case s.checkNow <- struct{}{}:
   369  	default:
   370  	}
   371  }
   372  
   373  // getWriteConcernErrorForProcessing extracts a driver.WriteConcernError from the provided error. This function returns
   374  // (error, true) if the error is a WriteConcernError and the falls under the requirements for SDAM error
   375  // handling and (nil, false) otherwise.
   376  func getWriteConcernErrorForProcessing(err error) (*driver.WriteConcernError, bool) {
   377  	writeCmdErr, ok := err.(driver.WriteCommandError)
   378  	if !ok {
   379  		return nil, false
   380  	}
   381  
   382  	wcerr := writeCmdErr.WriteConcernError
   383  	if wcerr != nil && (wcerr.NodeIsRecovering() || wcerr.NotPrimary()) {
   384  		return wcerr, true
   385  	}
   386  	return nil, false
   387  }
   388  
   389  // ProcessError handles SDAM error handling and implements driver.ErrorProcessor.
   390  func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessErrorResult {
   391  	// Ignore nil errors.
   392  	if err == nil {
   393  		return driver.NoChange
   394  	}
   395  
   396  	// Ignore errors from stale connections because the error came from a previous generation of the
   397  	// connection pool. The root cause of the error has aleady been handled, which is what caused
   398  	// the pool generation to increment. Processing errors for stale connections could result in
   399  	// handling the same error root cause multiple times (e.g. a temporary network interrupt causing
   400  	// all connections to the same server to return errors).
   401  	if conn.Stale() {
   402  		return driver.NoChange
   403  	}
   404  
   405  	// Must hold the processErrorLock while updating the server description and clearing the pool.
   406  	// Not holding the lock leads to possible out-of-order processing of pool.clear() and
   407  	// pool.ready() calls from concurrent server description updates.
   408  	s.processErrorLock.Lock()
   409  	defer s.processErrorLock.Unlock()
   410  
   411  	// Get the wire version and service ID from the connection description because they will never
   412  	// change for the lifetime of a connection and can possibly be different between connections to
   413  	// the same server.
   414  	connDesc := conn.Description()
   415  	wireVersion := connDesc.WireVersion
   416  	serviceID := connDesc.ServiceID
   417  
   418  	// Get the topology version from the Server description because the Server description is
   419  	// updated by heartbeats and errors, so typically has a more up-to-date topology version.
   420  	serverDesc := s.desc.Load().(description.Server)
   421  	topologyVersion := serverDesc.TopologyVersion
   422  
   423  	// We don't currently update the Server topology version when we create new application
   424  	// connections, so it's possible for a connection's topology version to be newer than the
   425  	// Server's topology version. Pick the "newest" of the two topology versions.
   426  	// Technically a nil topology version on a new database response should be considered a new
   427  	// topology version and replace the Server's topology version. However, we don't know if the
   428  	// connection's topology version is based on a new or old database response, so we ignore a nil
   429  	// topology version on the connection for now.
   430  	//
   431  	// TODO(GODRIVER-2841): Remove this logic once we set the Server description when we create
   432  	// TODO application connections because then the Server's topology version will always be the
   433  	// TODO latest known.
   434  	if tv := connDesc.TopologyVersion; tv != nil && topologyVersion.CompareToIncoming(tv) < 0 {
   435  		topologyVersion = tv
   436  	}
   437  
   438  	// Invalidate server description if not primary or node recovering error occurs.
   439  	// These errors can be reported as a command error or a write concern error.
   440  	if cerr, ok := err.(driver.Error); ok && (cerr.NodeIsRecovering() || cerr.NotPrimary()) {
   441  		// Ignore errors that came from when the database was on a previous topology version.
   442  		if topologyVersion.CompareToIncoming(cerr.TopologyVersion) >= 0 {
   443  			return driver.NoChange
   444  		}
   445  
   446  		// updates description to unknown
   447  		s.updateDescription(description.NewServerFromError(s.address, err, cerr.TopologyVersion))
   448  		s.RequestImmediateCheck()
   449  
   450  		res := driver.ServerMarkedUnknown
   451  		// If the node is shutting down or is older than 4.2, we synchronously clear the pool
   452  		if cerr.NodeIsShuttingDown() || wireVersion == nil || wireVersion.Max < wireVersion42 {
   453  			res = driver.ConnectionPoolCleared
   454  			s.pool.clear(err, serviceID)
   455  		}
   456  
   457  		return res
   458  	}
   459  	if wcerr, ok := getWriteConcernErrorForProcessing(err); ok {
   460  		// Ignore errors that came from when the database was on a previous topology version.
   461  		if topologyVersion.CompareToIncoming(wcerr.TopologyVersion) >= 0 {
   462  			return driver.NoChange
   463  		}
   464  
   465  		// updates description to unknown
   466  		s.updateDescription(description.NewServerFromError(s.address, err, wcerr.TopologyVersion))
   467  		s.RequestImmediateCheck()
   468  
   469  		res := driver.ServerMarkedUnknown
   470  		// If the node is shutting down or is older than 4.2, we synchronously clear the pool
   471  		if wcerr.NodeIsShuttingDown() || wireVersion == nil || wireVersion.Max < wireVersion42 {
   472  			res = driver.ConnectionPoolCleared
   473  			s.pool.clear(err, serviceID)
   474  		}
   475  		return res
   476  	}
   477  
   478  	wrappedConnErr := unwrapConnectionError(err)
   479  	if wrappedConnErr == nil {
   480  		return driver.NoChange
   481  	}
   482  
   483  	// Ignore transient timeout errors.
   484  	if netErr, ok := wrappedConnErr.(net.Error); ok && netErr.Timeout() {
   485  		return driver.NoChange
   486  	}
   487  	if wrappedConnErr == context.Canceled || wrappedConnErr == context.DeadlineExceeded {
   488  		return driver.NoChange
   489  	}
   490  
   491  	// For a non-timeout network error, we clear the pool, set the description to Unknown, and cancel the in-progress
   492  	// monitoring check. The check is cancelled last to avoid a post-cancellation reconnect racing with
   493  	// updateDescription.
   494  	s.updateDescription(description.NewServerFromError(s.address, err, nil))
   495  	s.pool.clear(err, serviceID)
   496  	s.cancelCheck()
   497  	return driver.ConnectionPoolCleared
   498  }
   499  
   500  // update handles performing heartbeats and updating any subscribers of the
   501  // newest description.Server retrieved.
   502  func (s *Server) update() {
   503  	defer s.closewg.Done()
   504  	heartbeatTicker := time.NewTicker(s.cfg.heartbeatInterval)
   505  	rateLimiter := time.NewTicker(minHeartbeatInterval)
   506  	defer heartbeatTicker.Stop()
   507  	defer rateLimiter.Stop()
   508  	checkNow := s.checkNow
   509  	done := s.done
   510  
   511  	defer func() {
   512  		_ = recover()
   513  	}()
   514  
   515  	closeServer := func() {
   516  		s.subLock.Lock()
   517  		for id, c := range s.subscribers {
   518  			close(c)
   519  			delete(s.subscribers, id)
   520  		}
   521  		s.subscriptionsClosed = true
   522  		s.subLock.Unlock()
   523  
   524  		// We don't need to take s.heartbeatLock here because closeServer is called synchronously when the select checks
   525  		// below detect that the server is being closed, so we can be sure that the connection isn't being used.
   526  		if s.conn != nil {
   527  			_ = s.conn.close()
   528  		}
   529  	}
   530  
   531  	waitUntilNextCheck := func() {
   532  		// Wait until heartbeatFrequency elapses, an application operation requests an immediate check, or the server
   533  		// is disconnecting.
   534  		select {
   535  		case <-heartbeatTicker.C:
   536  		case <-checkNow:
   537  		case <-done:
   538  			// Return because the next update iteration will check the done channel again and clean up.
   539  			return
   540  		}
   541  
   542  		// Ensure we only return if minHeartbeatFrequency has elapsed or the server is disconnecting.
   543  		select {
   544  		case <-rateLimiter.C:
   545  		case <-done:
   546  			return
   547  		}
   548  	}
   549  
   550  	timeoutCnt := 0
   551  	for {
   552  		// Check if the server is disconnecting. Even if waitForNextCheck has already read from the done channel, we
   553  		// can safely read from it again because Disconnect closes the channel.
   554  		select {
   555  		case <-done:
   556  			closeServer()
   557  			return
   558  		default:
   559  		}
   560  
   561  		previousDescription := s.Description()
   562  
   563  		// Perform the next check.
   564  		desc, err := s.check()
   565  		if err == errCheckCancelled {
   566  			if atomic.LoadInt64(&s.state) != serverConnected {
   567  				continue
   568  			}
   569  
   570  			// If the server is not disconnecting, the check was cancelled by an application operation after an error.
   571  			// Wait before running the next check.
   572  			waitUntilNextCheck()
   573  			continue
   574  		}
   575  
   576  		if isShortcut := func() bool {
   577  			// Must hold the processErrorLock while updating the server description and clearing the
   578  			// pool. Not holding the lock leads to possible out-of-order processing of pool.clear() and
   579  			// pool.ready() calls from concurrent server description updates.
   580  			s.processErrorLock.Lock()
   581  			defer s.processErrorLock.Unlock()
   582  
   583  			s.updateDescription(desc)
   584  			// Retry after the first timeout before clearing the pool in case of a FAAS pause as
   585  			// described in GODRIVER-2577.
   586  			if err := unwrapConnectionError(desc.LastError); err != nil && timeoutCnt < 1 {
   587  				if err == context.Canceled || err == context.DeadlineExceeded {
   588  					timeoutCnt++
   589  					// We want to immediately retry on timeout error. Continue to next loop.
   590  					return true
   591  				}
   592  				if err, ok := err.(net.Error); ok && err.Timeout() {
   593  					timeoutCnt++
   594  					// We want to immediately retry on timeout error. Continue to next loop.
   595  					return true
   596  				}
   597  			}
   598  			if err := desc.LastError; err != nil {
   599  				// Clear the pool once the description has been updated to Unknown. Pass in a nil service ID to clear
   600  				// because the monitoring routine only runs for non-load balanced deployments in which servers don't return
   601  				// IDs.
   602  				s.pool.clear(err, nil)
   603  			}
   604  			// We're either not handling a timeout error, or we just handled the 2nd consecutive
   605  			// timeout error. In either case, reset the timeout count to 0 and return false to
   606  			// continue the normal check process.
   607  			timeoutCnt = 0
   608  			return false
   609  		}(); isShortcut {
   610  			continue
   611  		}
   612  
   613  		// If the server supports streaming or we're already streaming, we want to move to streaming the next response
   614  		// without waiting. If the server has transitioned to Unknown from a network error, we want to do another
   615  		// check without waiting in case it was a transient error and the server isn't actually down.
   616  		serverSupportsStreaming := desc.Kind != description.Unknown && desc.TopologyVersion != nil
   617  		connectionIsStreaming := s.conn != nil && s.conn.getCurrentlyStreaming()
   618  		transitionedFromNetworkError := desc.LastError != nil && unwrapConnectionError(desc.LastError) != nil &&
   619  			previousDescription.Kind != description.Unknown
   620  
   621  		if serverSupportsStreaming || connectionIsStreaming || transitionedFromNetworkError {
   622  			continue
   623  		}
   624  
   625  		// The server either does not support the streamable protocol or is not in a healthy state, so we wait until
   626  		// the next check.
   627  		waitUntilNextCheck()
   628  	}
   629  }
   630  
   631  // updateDescription handles updating the description on the Server, notifying
   632  // subscribers, and potentially draining the connection pool. The initial
   633  // parameter is used to determine if this is the first description from the
   634  // server.
   635  func (s *Server) updateDescription(desc description.Server) {
   636  	if s.cfg.loadBalanced {
   637  		// In load balanced mode, there are no updates from the monitoring routine. For errors encountered in pooled
   638  		// connections, the server should not be marked Unknown to ensure that the LB remains selectable.
   639  		return
   640  	}
   641  
   642  	defer func() {
   643  		//  ¯\_(ツ)_/¯
   644  		_ = recover()
   645  	}()
   646  
   647  	// Anytime we update the server description to something other than "unknown", set the pool to
   648  	// "ready". Do this before updating the description so that connections can be checked out as
   649  	// soon as the server is selectable. If the pool is already ready, this operation is a no-op.
   650  	// Note that this behavior is roughly consistent with the current Go driver behavior (connects
   651  	// to all servers, even non-data-bearing nodes) but deviates slightly from CMAP spec, which
   652  	// specifies a more restricted set of server descriptions and topologies that should mark the
   653  	// pool ready. We don't have access to the topology here, so prefer the current Go driver
   654  	// behavior for simplicity.
   655  	if desc.Kind != description.Unknown {
   656  		_ = s.pool.ready()
   657  	}
   658  
   659  	// Use the updateTopologyCallback to update the parent Topology and get the description that should be stored.
   660  	callback, ok := s.updateTopologyCallback.Load().(updateTopologyCallback)
   661  	if ok && callback != nil {
   662  		desc = callback(desc)
   663  	}
   664  	s.desc.Store(desc)
   665  
   666  	s.subLock.Lock()
   667  	for _, c := range s.subscribers {
   668  		select {
   669  		// drain the channel if it isn't empty
   670  		case <-c:
   671  		default:
   672  		}
   673  		c <- desc
   674  	}
   675  	s.subLock.Unlock()
   676  }
   677  
   678  // createConnection creates a new connection instance but does not call connect on it. The caller must call connect
   679  // before the connection can be used for network operations.
   680  func (s *Server) createConnection() *connection {
   681  	opts := copyConnectionOpts(s.cfg.connectionOpts)
   682  	opts = append(opts,
   683  		WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
   684  		WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
   685  		WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
   686  		// We override whatever handshaker is currently attached to the options with a basic
   687  		// one because need to make sure we don't do auth.
   688  		WithHandshaker(func(h Handshaker) Handshaker {
   689  			return operation.NewHello().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts).
   690  				ServerAPI(s.cfg.serverAPI)
   691  		}),
   692  		// Override any monitors specified in options with nil to avoid monitoring heartbeats.
   693  		WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return nil }),
   694  	)
   695  
   696  	return newConnection(s.address, opts...)
   697  }
   698  
   699  func copyConnectionOpts(opts []ConnectionOption) []ConnectionOption {
   700  	optsCopy := make([]ConnectionOption, len(opts))
   701  	copy(optsCopy, opts)
   702  	return optsCopy
   703  }
   704  
   705  func (s *Server) setupHeartbeatConnection() error {
   706  	conn := s.createConnection()
   707  
   708  	// Take the lock when assigning the context and connection because they're accessed by cancelCheck.
   709  	s.heartbeatLock.Lock()
   710  	if s.heartbeatCtxCancel != nil {
   711  		// Ensure the previous context is cancelled to avoid a leak.
   712  		s.heartbeatCtxCancel()
   713  	}
   714  	s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx)
   715  	s.conn = conn
   716  	s.heartbeatLock.Unlock()
   717  
   718  	return s.conn.connect(s.heartbeatCtx)
   719  }
   720  
   721  // cancelCheck cancels in-progress connection dials and reads. It does not set any fields on the server.
   722  func (s *Server) cancelCheck() {
   723  	var conn *connection
   724  
   725  	// Take heartbeatLock for mutual exclusion with the checks in the update function.
   726  	s.heartbeatLock.Lock()
   727  	if s.heartbeatCtx != nil {
   728  		s.heartbeatCtxCancel()
   729  	}
   730  	conn = s.conn
   731  	s.heartbeatLock.Unlock()
   732  
   733  	if conn == nil {
   734  		return
   735  	}
   736  
   737  	// If the connection exists, we need to wait for it to be connected because conn.connect() and
   738  	// conn.close() cannot be called concurrently. If the connection wasn't successfully opened, its
   739  	// state was set back to disconnected, so calling conn.close() will be a no-op.
   740  	conn.closeConnectContext()
   741  	conn.wait()
   742  	_ = conn.close()
   743  }
   744  
   745  func (s *Server) checkWasCancelled() bool {
   746  	return s.heartbeatCtx.Err() != nil
   747  }
   748  
   749  func (s *Server) createBaseOperation(conn driver.Connection) *operation.Hello {
   750  	return operation.
   751  		NewHello().
   752  		ClusterClock(s.cfg.clock).
   753  		Deployment(driver.SingleConnectionDeployment{conn}).
   754  		ServerAPI(s.cfg.serverAPI)
   755  }
   756  
   757  func (s *Server) check() (description.Server, error) {
   758  	var descPtr *description.Server
   759  	var err error
   760  	var duration time.Duration
   761  
   762  	start := time.Now()
   763  	if s.conn == nil || s.conn.closed() || s.checkWasCancelled() {
   764  		// Create a new connection if this is the first check, the connection was closed after an error during the previous
   765  		// check, or the previous check was cancelled.
   766  		if s.conn != nil {
   767  			s.publishServerHeartbeatStartedEvent(s.conn.ID(), false)
   768  		}
   769  		// Create a new connection and add it's handshake RTT as a sample.
   770  		err = s.setupHeartbeatConnection()
   771  		duration = time.Since(start)
   772  		if err == nil {
   773  			// Use the description from the connection handshake as the value for this check.
   774  			s.rttMonitor.addSample(s.conn.helloRTT)
   775  			descPtr = &s.conn.desc
   776  			if s.conn != nil {
   777  				s.publishServerHeartbeatSucceededEvent(s.conn.ID(), duration, s.conn.desc, false)
   778  			}
   779  		} else {
   780  			err = unwrapConnectionError(err)
   781  			if s.conn != nil {
   782  				s.publishServerHeartbeatFailedEvent(s.conn.ID(), duration, err, false)
   783  			}
   784  		}
   785  	} else {
   786  		// An existing connection is being used. Use the server description properties to execute the right heartbeat.
   787  
   788  		// Wrap conn in a type that implements driver.StreamerConnection.
   789  		heartbeatConn := initConnection{s.conn}
   790  		baseOperation := s.createBaseOperation(heartbeatConn)
   791  		previousDescription := s.Description()
   792  		streamable := previousDescription.TopologyVersion != nil
   793  
   794  		s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable)
   795  		switch {
   796  		case s.conn.getCurrentlyStreaming():
   797  			// The connection is already in a streaming state, so we stream the next response.
   798  			err = baseOperation.StreamResponse(s.heartbeatCtx, heartbeatConn)
   799  		case streamable:
   800  			// The server supports the streamable protocol. Set the socket timeout to
   801  			// connectTimeoutMS+heartbeatFrequencyMS and execute an awaitable hello request. Set conn.canStream so
   802  			// the wire message will advertise streaming support to the server.
   803  
   804  			// Calculation for maxAwaitTimeMS is taken from time.Duration.Milliseconds (added in Go 1.13).
   805  			maxAwaitTimeMS := int64(s.cfg.heartbeatInterval) / 1e6
   806  			// If connectTimeoutMS=0, the socket timeout should be infinite. Otherwise, it is connectTimeoutMS +
   807  			// heartbeatFrequencyMS to account for the fact that the query will block for heartbeatFrequencyMS
   808  			// server-side.
   809  			socketTimeout := s.cfg.heartbeatTimeout
   810  			if socketTimeout != 0 {
   811  				socketTimeout += s.cfg.heartbeatInterval
   812  			}
   813  			s.conn.setSocketTimeout(socketTimeout)
   814  			baseOperation = baseOperation.TopologyVersion(previousDescription.TopologyVersion).
   815  				MaxAwaitTimeMS(maxAwaitTimeMS)
   816  			s.conn.setCanStream(true)
   817  			err = baseOperation.Execute(s.heartbeatCtx)
   818  		default:
   819  			// The server doesn't support the awaitable protocol. Set the socket timeout to connectTimeoutMS and
   820  			// execute a regular heartbeat without any additional parameters.
   821  
   822  			s.conn.setSocketTimeout(s.cfg.heartbeatTimeout)
   823  			err = baseOperation.Execute(s.heartbeatCtx)
   824  		}
   825  		duration = time.Since(start)
   826  
   827  		if err == nil {
   828  			tempDesc := baseOperation.Result(s.address)
   829  			descPtr = &tempDesc
   830  			s.publishServerHeartbeatSucceededEvent(s.conn.ID(), duration, tempDesc, s.conn.getCurrentlyStreaming() || streamable)
   831  		} else {
   832  			// Close the connection here rather than below so we ensure we're not closing a connection that wasn't
   833  			// successfully created.
   834  			if s.conn != nil {
   835  				_ = s.conn.close()
   836  			}
   837  			s.publishServerHeartbeatFailedEvent(s.conn.ID(), duration, err, s.conn.getCurrentlyStreaming() || streamable)
   838  		}
   839  	}
   840  
   841  	if descPtr != nil {
   842  		// The check was successful. Set the average RTT and the 90th percentile RTT and return.
   843  		desc := *descPtr
   844  		desc = desc.SetAverageRTT(s.rttMonitor.EWMA())
   845  		desc.HeartbeatInterval = s.cfg.heartbeatInterval
   846  		return desc, nil
   847  	}
   848  
   849  	if s.checkWasCancelled() {
   850  		// If the previous check was cancelled, we don't want to clear the pool. Return a sentinel error so the caller
   851  		// will know that an actual error didn't occur.
   852  		return emptyDescription, errCheckCancelled
   853  	}
   854  
   855  	// An error occurred. We reset the RTT monitor for all errors and return an Unknown description. The pool must also
   856  	// be cleared, but only after the description has already been updated, so that is handled by the caller.
   857  	topologyVersion := extractTopologyVersion(err)
   858  	s.rttMonitor.reset()
   859  	return description.NewServerFromError(s.address, err, topologyVersion), nil
   860  }
   861  
   862  func extractTopologyVersion(err error) *description.TopologyVersion {
   863  	if ce, ok := err.(ConnectionError); ok {
   864  		err = ce.Wrapped
   865  	}
   866  
   867  	switch converted := err.(type) {
   868  	case driver.Error:
   869  		return converted.TopologyVersion
   870  	case driver.WriteCommandError:
   871  		if converted.WriteConcernError != nil {
   872  			return converted.WriteConcernError.TopologyVersion
   873  		}
   874  	}
   875  
   876  	return nil
   877  }
   878  
   879  // RTTMonitor returns this server's round-trip-time monitor.
   880  func (s *Server) RTTMonitor() driver.RTTMonitor {
   881  	return s.rttMonitor
   882  }
   883  
   884  // OperationCount returns the current number of in-progress operations for this server.
   885  func (s *Server) OperationCount() int64 {
   886  	return atomic.LoadInt64(&s.operationCount)
   887  }
   888  
   889  // String implements the Stringer interface.
   890  func (s *Server) String() string {
   891  	desc := s.Description()
   892  	state := atomic.LoadInt64(&s.state)
   893  	str := fmt.Sprintf("Addr: %s, Type: %s, State: %s",
   894  		s.address, desc.Kind, serverStateString(state))
   895  	if len(desc.Tags) != 0 {
   896  		str += fmt.Sprintf(", Tag sets: %s", desc.Tags)
   897  	}
   898  	if state == serverConnected {
   899  		str += fmt.Sprintf(", Average RTT: %s, Min RTT: %s", desc.AverageRTT, s.RTTMonitor().Min())
   900  	}
   901  	if desc.LastError != nil {
   902  		str += fmt.Sprintf(", Last error: %s", desc.LastError)
   903  	}
   904  
   905  	return str
   906  }
   907  
   908  // ServerSubscription represents a subscription to the description.Server updates for
   909  // a specific server.
   910  type ServerSubscription struct {
   911  	C  <-chan description.Server
   912  	s  *Server
   913  	id uint64
   914  }
   915  
   916  // Unsubscribe unsubscribes this ServerSubscription from updates and closes the
   917  // subscription channel.
   918  func (ss *ServerSubscription) Unsubscribe() error {
   919  	ss.s.subLock.Lock()
   920  	defer ss.s.subLock.Unlock()
   921  	if ss.s.subscriptionsClosed {
   922  		return nil
   923  	}
   924  
   925  	ch, ok := ss.s.subscribers[ss.id]
   926  	if !ok {
   927  		return nil
   928  	}
   929  
   930  	close(ch)
   931  	delete(ss.s.subscribers, ss.id)
   932  
   933  	return nil
   934  }
   935  
   936  // publishes a ServerOpeningEvent to indicate the server is being initialized
   937  func (s *Server) publishServerOpeningEvent(addr address.Address) {
   938  	if s == nil {
   939  		return
   940  	}
   941  
   942  	serverOpening := &event.ServerOpeningEvent{
   943  		Address:    addr,
   944  		TopologyID: s.topologyID,
   945  	}
   946  
   947  	if s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerOpening != nil {
   948  		s.cfg.serverMonitor.ServerOpening(serverOpening)
   949  	}
   950  }
   951  
   952  // publishes a ServerHeartbeatStartedEvent to indicate a hello command has started
   953  func (s *Server) publishServerHeartbeatStartedEvent(connectionID string, await bool) {
   954  	serverHeartbeatStarted := &event.ServerHeartbeatStartedEvent{
   955  		ConnectionID: connectionID,
   956  		Awaited:      await,
   957  	}
   958  
   959  	if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatStarted != nil {
   960  		s.cfg.serverMonitor.ServerHeartbeatStarted(serverHeartbeatStarted)
   961  	}
   962  }
   963  
   964  // publishes a ServerHeartbeatSucceededEvent to indicate hello has succeeded
   965  func (s *Server) publishServerHeartbeatSucceededEvent(connectionID string,
   966  	duration time.Duration,
   967  	desc description.Server,
   968  	await bool,
   969  ) {
   970  	serverHeartbeatSucceeded := &event.ServerHeartbeatSucceededEvent{
   971  		DurationNanos: duration.Nanoseconds(),
   972  		Duration:      duration,
   973  		Reply:         desc,
   974  		ConnectionID:  connectionID,
   975  		Awaited:       await,
   976  	}
   977  
   978  	if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatSucceeded != nil {
   979  		s.cfg.serverMonitor.ServerHeartbeatSucceeded(serverHeartbeatSucceeded)
   980  	}
   981  }
   982  
   983  // publishes a ServerHeartbeatFailedEvent to indicate hello has failed
   984  func (s *Server) publishServerHeartbeatFailedEvent(connectionID string,
   985  	duration time.Duration,
   986  	err error,
   987  	await bool,
   988  ) {
   989  	serverHeartbeatFailed := &event.ServerHeartbeatFailedEvent{
   990  		DurationNanos: duration.Nanoseconds(),
   991  		Duration:      duration,
   992  		Failure:       err,
   993  		ConnectionID:  connectionID,
   994  		Awaited:       await,
   995  	}
   996  
   997  	if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatFailed != nil {
   998  		s.cfg.serverMonitor.ServerHeartbeatFailed(serverHeartbeatFailed)
   999  	}
  1000  }
  1001  
  1002  // unwrapConnectionError returns the connection error wrapped by err, or nil if err does not wrap a connection error.
  1003  func unwrapConnectionError(err error) error {
  1004  	// This is essentially an implementation of errors.As to unwrap this error until we get a ConnectionError and then
  1005  	// return ConnectionError.Wrapped.
  1006  
  1007  	connErr, ok := err.(ConnectionError)
  1008  	if ok {
  1009  		return connErr.Wrapped
  1010  	}
  1011  
  1012  	driverErr, ok := err.(driver.Error)
  1013  	if !ok || !driverErr.NetworkError() {
  1014  		return nil
  1015  	}
  1016  
  1017  	connErr, ok = driverErr.Wrapped.(ConnectionError)
  1018  	if ok {
  1019  		return connErr.Wrapped
  1020  	}
  1021  
  1022  	return nil
  1023  }