
     1  // Copyright (C) MongoDB, Inc. 2017-present.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License"); you may
     4  // not use this file except in compliance with the License. You may obtain
     5  // a copy of the License at
     7  package topology
     9  import (
    10  	"context"
    11  	"errors"
    12  	"fmt"
    13  	"net"
    14  	"sync"
    15  	"sync/atomic"
    16  	"time"
    18  	""
    19  	""
    20  	""
    21  	""
    22  	""
    23  	""
    24  )
    26  const minHeartbeatInterval = 500 * time.Millisecond
    27  const wireVersion42 = 8 // Wire version for MongoDB 4.2
    29  // Server state constants.
    30  const (
    31  	serverDisconnected int64 = iota
    32  	serverDisconnecting
    33  	serverConnected
    34  )
    36  func serverStateString(state int64) string {
    37  	switch state {
    38  	case serverDisconnected:
    39  		return "Disconnected"
    40  	case serverDisconnecting:
    41  		return "Disconnecting"
    42  	case serverConnected:
    43  		return "Connected"
    44  	}
    46  	return ""
    47  }
    49  var (
    50  	// ErrServerClosed occurs when an attempt to Get a connection is made after
    51  	// the server has been closed.
    52  	ErrServerClosed = errors.New("server is closed")
    53  	// ErrServerConnected occurs when at attempt to Connect is made after a server
    54  	// has already been connected.
    55  	ErrServerConnected = errors.New("server is connected")
    57  	errCheckCancelled = errors.New("server check cancelled")
    58  	emptyDescription  = description.NewDefaultServer("")
    59  )
    61  // SelectedServer represents a specific server that was selected during server selection.
    62  // It contains the kind of the topology it was selected from.
    63  type SelectedServer struct {
    64  	*Server
    66  	Kind description.TopologyKind
    67  }
    69  // Description returns a description of the server as of the last heartbeat.
    70  func (ss *SelectedServer) Description() description.SelectedServer {
    71  	sdesc := ss.Server.Description()
    72  	return description.SelectedServer{
    73  		Server: sdesc,
    74  		Kind:   ss.Kind,
    75  	}
    76  }
// Server is a single server within a topology. It owns the connection pool for its address and, when monitoring is
// enabled and the deployment is not load balanced, a monitoring goroutine (update) and an RTT monitor that are
// started by Connect and stopped by Disconnect.
type Server struct {
	// The following integer fields must be accessed using the atomic package and should be at the
	// beginning of the struct.
	// - atomic bug: https://pkg.go.dev/sync/atomic#pkg-note-BUG
	// - suggested layout: keep every atomically-accessed 64-bit field first in the struct.
	// On 32-bit platforms, 64-bit atomic operations need 64-bit-aligned operands, and only the first word of an
	// allocated struct is guaranteed to be 64-bit aligned, so do not reorder these fields.

	// state is one of serverDisconnected, serverDisconnecting, or serverConnected.
	state int64
	// operationCount counts connection requests made through Connection that are waiting in or checked out of the
	// pool; it is decremented when a checkOut fails or when the returned connection is cleaned up.
	operationCount int64

	cfg     *serverConfig
	address address.Address

	// connection related fields
	pool *pool

	// goroutine management fields
	// done is closed by Disconnect to tell the update() goroutine to stop.
	done chan struct{}
	// checkNow has a buffer of one and is signalled by RequestImmediateCheck to wake the monitoring routine.
	checkNow chan struct{}
	// NOTE(review): disconnecting is created in NewServer but not referenced in the code visible here.
	disconnecting chan struct{}
	// closewg tracks the update() goroutine so that Disconnect can wait for it to exit.
	closewg sync.WaitGroup

	// description related fields
	desc                   atomic.Value // holds a description.Server
	updateTopologyCallback atomic.Value // holds an updateTopologyCallback; set to nil by Disconnect
	topologyID             primitive.ObjectID

	// subscriber related fields, all guarded by subLock.
	subLock             sync.Mutex
	subscribers         map[uint64]chan description.Server
	currentSubscriberID uint64
	subscriptionsClosed bool

	// heartbeat and cancellation related fields
	// globalCtx should be created in NewServer and cancelled in Disconnect to signal that the server is shutting down.
	// heartbeatCtx should be used for individual heartbeats and should be a child of globalCtx so that it will be
	// cancelled automatically during shutdown.
	// heartbeatLock is held when assigning heartbeatCtx, heartbeatCtxCancel, and conn because cancelCheck reads them
	// from other goroutines.
	heartbeatLock      sync.Mutex
	conn               *connection
	globalCtx          context.Context
	globalCtxCancel    context.CancelFunc
	heartbeatCtx       context.Context
	heartbeatCtxCancel context.CancelFunc

	// processErrorLock serializes server description updates with the pool.clear()/pool.ready() calls that
	// accompany them, so that concurrent updates are not applied out of order.
	processErrorLock sync.Mutex
	rttMonitor       *rttMonitor
}
// updateTopologyCallback is called by a Server with each new server description so that the parent Topology can be
// updated. The callback must return the server description that the Server should store, which may differ from the
// description passed in.
type updateTopologyCallback func(description.Server) description.Server
   131  // ConnectServer creates a new Server and then initializes it using the
   132  // Connect method.
   133  func ConnectServer(addr address.Address, updateCallback updateTopologyCallback, topologyID primitive.ObjectID, opts ...ServerOption) (*Server, error) {
   134  	srvr := NewServer(addr, topologyID, opts...)
   135  	err := srvr.Connect(updateCallback)
   136  	if err != nil {
   137  		return nil, err
   138  	}
   139  	return srvr, nil
   140  }
   142  // NewServer creates a new server. The mongodb server at the address will be monitored
   143  // on an internal monitoring goroutine.
   144  func NewServer(addr address.Address, topologyID primitive.ObjectID, opts ...ServerOption) *Server {
   145  	cfg := newServerConfig(opts...)
   146  	globalCtx, globalCtxCancel := context.WithCancel(context.Background())
   147  	s := &Server{
   148  		state: serverDisconnected,
   150  		cfg:     cfg,
   151  		address: addr,
   153  		done:          make(chan struct{}),
   154  		checkNow:      make(chan struct{}, 1),
   155  		disconnecting: make(chan struct{}),
   157  		topologyID: topologyID,
   159  		subscribers:     make(map[uint64]chan description.Server),
   160  		globalCtx:       globalCtx,
   161  		globalCtxCancel: globalCtxCancel,
   162  	}
   163  	s.desc.Store(description.NewDefaultServer(addr))
   164  	rttCfg := &rttConfig{
   165  		interval:           cfg.heartbeatInterval,
   166  		minRTTWindow:       5 * time.Minute,
   167  		createConnectionFn: s.createConnection,
   168  		createOperationFn:  s.createBaseOperation,
   169  	}
   170  	s.rttMonitor = newRTTMonitor(rttCfg)
   172  	pc := poolConfig{
   173  		Address:          addr,
   174  		MinPoolSize:      cfg.minConns,
   175  		MaxPoolSize:      cfg.maxConns,
   176  		MaxConnecting:    cfg.maxConnecting,
   177  		MaxIdleTime:      cfg.poolMaxIdleTime,
   178  		MaintainInterval: cfg.poolMaintainInterval,
   179  		PoolMonitor:      cfg.poolMonitor,
   180  		Logger:           cfg.logger,
   181  		handshakeErrFn:   s.ProcessHandshakeError,
   182  	}
   184  	connectionOpts := copyConnectionOpts(cfg.connectionOpts)
   185  	s.pool = newPool(pc, connectionOpts...)
   186  	s.publishServerOpeningEvent(s.address)
   188  	return s
   189  }
   191  // Connect initializes the Server by starting background monitoring goroutines.
   192  // This method must be called before a Server can be used.
   193  func (s *Server) Connect(updateCallback updateTopologyCallback) error {
   194  	if !atomic.CompareAndSwapInt64(&s.state, serverDisconnected, serverConnected) {
   195  		return ErrServerConnected
   196  	}
   198  	desc := description.NewDefaultServer(s.address)
   199  	if s.cfg.loadBalanced {
   200  		// LBs automatically start off with kind LoadBalancer because there is no monitoring routine for state changes.
   201  		desc.Kind = description.LoadBalancer
   202  	}
   203  	s.desc.Store(desc)
   204  	s.updateTopologyCallback.Store(updateCallback)
   206  	if !s.cfg.monitoringDisabled && !s.cfg.loadBalanced {
   207  		s.rttMonitor.connect()
   208  		s.closewg.Add(1)
   209  		go s.update()
   210  	}
   212  	// The CMAP spec describes that pools should only be marked "ready" when the server description
   213  	// is updated to something other than "Unknown". However, we maintain the previous Server
   214  	// behavior here and immediately mark the pool as ready during Connect() to simplify and speed
   215  	// up the Client startup behavior. The risk of marking a pool as ready proactively during
   216  	// Connect() is that we could attempt to create connections to a server that was configured
   217  	// erroneously until the first server check or checkOut() failure occurs, when the SDAM error
   218  	// handler would transition the Server back to "Unknown" and set the pool to "paused".
   219  	return s.pool.ready()
   220  }
// Disconnect closes sockets to the server referenced by this Server.
// Subscriptions to this Server will be closed. Disconnect will shut down
// any monitoring goroutines, close the idle connection pool, and will
// wait until all the in use connections have been returned to the connection
// pool and are closed before returning. If the context expires via
// cancellation, deadline, or timeout before the in use connections have been
// returned, the in use connections will be closed, resulting in the failure of
// any in flight read or write operations. If this method returns with no
// errors, all connections associated with this Server have been closed.
//
// Disconnect returns ErrServerClosed if the Server is not currently connected (for example, if Connect was never
// called or Disconnect was already called).
func (s *Server) Disconnect(ctx context.Context) error {
	// Leave the connected state first so that Connection and Subscribe reject new callers during shutdown.
	if !atomic.CompareAndSwapInt64(&s.state, serverConnected, serverDisconnecting) {
		return ErrServerClosed
	}

	// Drop the topology callback so that description updates made during shutdown are not forwarded to the parent
	// Topology (updateDescription skips a nil callback).
	s.updateTopologyCallback.Store((updateTopologyCallback)(nil))

	// Cancel the global context so any new contexts created from it will be automatically cancelled. Close the done
	// channel so the update() routine will know that it can stop. Cancel any in-progress monitoring checks at the end.
	// The done channel is closed before cancelling the check so the update routine() will immediately detect that it
	// can stop rather than trying to create new connections until the read from done succeeds.
	s.globalCtxCancel()
	close(s.done)
	s.cancelCheck()

	s.rttMonitor.disconnect()
	s.pool.close(ctx)

	// Wait for the update() goroutine started by Connect to exit before reporting the Server as disconnected.
	s.closewg.Wait()
	atomic.StoreInt64(&s.state, serverDisconnected)

	return nil
}
   255  // Connection gets a connection to the server.
   256  func (s *Server) Connection(ctx context.Context) (driver.Connection, error) {
   257  	if atomic.LoadInt64(&s.state) != serverConnected {
   258  		return nil, ErrServerClosed
   259  	}
   261  	// Increment the operation count before calling checkOut to make sure that all connection
   262  	// requests are included in the operation count, including those in the wait queue. If we got an
   263  	// error instead of a connection, immediately decrement the operation count.
   264  	atomic.AddInt64(&s.operationCount, 1)
   265  	conn, err := s.pool.checkOut(ctx)
   266  	if err != nil {
   267  		atomic.AddInt64(&s.operationCount, -1)
   268  		return nil, err
   269  	}
   271  	return &Connection{
   272  		connection: conn,
   273  		cleanupServerFn: func() {
   274  			// Decrement the operation count whenever the caller is done with the connection. Note
   275  			// that cleanupServerFn() is not called while the connection is pinned to a cursor or
   276  			// transaction, so the operation count is not decremented until the cursor is closed or
   277  			// the transaction is committed or aborted. Use an int64 instead of a uint64 to mitigate
   278  			// the impact of any possible bugs that could cause the uint64 to underflow, which would
   279  			// make the server much less selectable.
   280  			atomic.AddInt64(&s.operationCount, -1)
   281  		},
   282  	}, nil
   283  }
   285  // ProcessHandshakeError implements SDAM error handling for errors that occur before a connection
   286  // finishes handshaking.
   287  func (s *Server) ProcessHandshakeError(err error, startingGenerationNumber uint64, serviceID *primitive.ObjectID) {
   288  	// Ignore the error if the server is behind a load balancer but the service ID is unknown. This indicates that the
   289  	// error happened when dialing the connection or during the MongoDB handshake, so we don't know the service ID to
   290  	// use for clearing the pool.
   291  	if err == nil || s.cfg.loadBalanced && serviceID == nil {
   292  		return
   293  	}
   294  	// Ignore the error if the connection is stale.
   295  	if startingGenerationNumber < s.pool.generation.getGeneration(serviceID) {
   296  		return
   297  	}
   299  	// Unwrap any connection errors. If there is no wrapped connection error, then the error should
   300  	// not result in any Server state change (e.g. a command error from the database).
   301  	wrappedConnErr := unwrapConnectionError(err)
   302  	if wrappedConnErr == nil {
   303  		return
   304  	}
   306  	// Must hold the processErrorLock while updating the server description and clearing the pool.
   307  	// Not holding the lock leads to possible out-of-order processing of pool.clear() and
   308  	// pool.ready() calls from concurrent server description updates.
   309  	s.processErrorLock.Lock()
   310  	defer s.processErrorLock.Unlock()
   312  	// Since the only kind of ConnectionError we receive from pool.Get will be an initialization error, we should set
   313  	// the description.Server appropriately. The description should not have a TopologyVersion because the staleness
   314  	// checking logic above has already determined that this description is not stale.
   315  	s.updateDescription(description.NewServerFromError(s.address, wrappedConnErr, nil))
   316  	s.pool.clear(err, serviceID)
   317  	s.cancelCheck()
   318  }
   320  // Description returns a description of the server as of the last heartbeat.
   321  func (s *Server) Description() description.Server {
   322  	return s.desc.Load().(description.Server)
   323  }
   325  // SelectedDescription returns a description.SelectedServer with a Kind of
   326  // Single. This can be used when performing tasks like monitoring a batch
   327  // of servers and you want to run one off commands against those servers.
   328  func (s *Server) SelectedDescription() description.SelectedServer {
   329  	sdesc := s.Description()
   330  	return description.SelectedServer{
   331  		Server: sdesc,
   332  		Kind:   description.Single,
   333  	}
   334  }
   336  // Subscribe returns a ServerSubscription which has a channel on which all
   337  // updated server descriptions will be sent. The channel will have a buffer
   338  // size of one, and will be pre-populated with the current description.
   339  func (s *Server) Subscribe() (*ServerSubscription, error) {
   340  	if atomic.LoadInt64(&s.state) != serverConnected {
   341  		return nil, ErrSubscribeAfterClosed
   342  	}
   343  	ch := make(chan description.Server, 1)
   344  	ch <- s.desc.Load().(description.Server)
   346  	s.subLock.Lock()
   347  	defer s.subLock.Unlock()
   348  	if s.subscriptionsClosed {
   349  		return nil, ErrSubscribeAfterClosed
   350  	}
   351  	id := s.currentSubscriberID
   352  	s.subscribers[id] = ch
   353  	s.currentSubscriberID++
   355  	ss := &ServerSubscription{
   356  		C:  ch,
   357  		s:  s,
   358  		id: id,
   359  	}
   361  	return ss, nil
   362  }
   364  // RequestImmediateCheck will cause the server to send a heartbeat immediately
   365  // instead of waiting for the heartbeat timeout.
   366  func (s *Server) RequestImmediateCheck() {
   367  	select {
   368  	case s.checkNow <- struct{}{}:
   369  	default:
   370  	}
   371  }
   373  // getWriteConcernErrorForProcessing extracts a driver.WriteConcernError from the provided error. This function returns
   374  // (error, true) if the error is a WriteConcernError and the falls under the requirements for SDAM error
   375  // handling and (nil, false) otherwise.
   376  func getWriteConcernErrorForProcessing(err error) (*driver.WriteConcernError, bool) {
   377  	writeCmdErr, ok := err.(driver.WriteCommandError)
   378  	if !ok {
   379  		return nil, false
   380  	}
   382  	wcerr := writeCmdErr.WriteConcernError
   383  	if wcerr != nil && (wcerr.NodeIsRecovering() || wcerr.NotPrimary()) {
   384  		return wcerr, true
   385  	}
   386  	return nil, false
   387  }
   389  // ProcessError handles SDAM error handling and implements driver.ErrorProcessor.
   390  func (s *Server) ProcessError(err error, conn driver.Connection) driver.ProcessErrorResult {
   391  	// Ignore nil errors.
   392  	if err == nil {
   393  		return driver.NoChange
   394  	}
   396  	// Ignore errors from stale connections because the error came from a previous generation of the
   397  	// connection pool. The root cause of the error has aleady been handled, which is what caused
   398  	// the pool generation to increment. Processing errors for stale connections could result in
   399  	// handling the same error root cause multiple times (e.g. a temporary network interrupt causing
   400  	// all connections to the same server to return errors).
   401  	if conn.Stale() {
   402  		return driver.NoChange
   403  	}
   405  	// Must hold the processErrorLock while updating the server description and clearing the pool.
   406  	// Not holding the lock leads to possible out-of-order processing of pool.clear() and
   407  	// pool.ready() calls from concurrent server description updates.
   408  	s.processErrorLock.Lock()
   409  	defer s.processErrorLock.Unlock()
   411  	// Get the wire version and service ID from the connection description because they will never
   412  	// change for the lifetime of a connection and can possibly be different between connections to
   413  	// the same server.
   414  	connDesc := conn.Description()
   415  	wireVersion := connDesc.WireVersion
   416  	serviceID := connDesc.ServiceID
   418  	// Get the topology version from the Server description because the Server description is
   419  	// updated by heartbeats and errors, so typically has a more up-to-date topology version.
   420  	serverDesc := s.desc.Load().(description.Server)
   421  	topologyVersion := serverDesc.TopologyVersion
   423  	// We don't currently update the Server topology version when we create new application
   424  	// connections, so it's possible for a connection's topology version to be newer than the
   425  	// Server's topology version. Pick the "newest" of the two topology versions.
   426  	// Technically a nil topology version on a new database response should be considered a new
   427  	// topology version and replace the Server's topology version. However, we don't know if the
   428  	// connection's topology version is based on a new or old database response, so we ignore a nil
   429  	// topology version on the connection for now.
   430  	//
   431  	// TODO(GODRIVER-2841): Remove this logic once we set the Server description when we create
   432  	// TODO application connections because then the Server's topology version will always be the
   433  	// TODO latest known.
   434  	if tv := connDesc.TopologyVersion; tv != nil && topologyVersion.CompareToIncoming(tv) < 0 {
   435  		topologyVersion = tv
   436  	}
   438  	// Invalidate server description if not primary or node recovering error occurs.
   439  	// These errors can be reported as a command error or a write concern error.
   440  	if cerr, ok := err.(driver.Error); ok && (cerr.NodeIsRecovering() || cerr.NotPrimary()) {
   441  		// Ignore errors that came from when the database was on a previous topology version.
   442  		if topologyVersion.CompareToIncoming(cerr.TopologyVersion) >= 0 {
   443  			return driver.NoChange
   444  		}
   446  		// updates description to unknown
   447  		s.updateDescription(description.NewServerFromError(s.address, err, cerr.TopologyVersion))
   448  		s.RequestImmediateCheck()
   450  		res := driver.ServerMarkedUnknown
   451  		// If the node is shutting down or is older than 4.2, we synchronously clear the pool
   452  		if cerr.NodeIsShuttingDown() || wireVersion == nil || wireVersion.Max < wireVersion42 {
   453  			res = driver.ConnectionPoolCleared
   454  			s.pool.clear(err, serviceID)
   455  		}
   457  		return res
   458  	}
   459  	if wcerr, ok := getWriteConcernErrorForProcessing(err); ok {
   460  		// Ignore errors that came from when the database was on a previous topology version.
   461  		if topologyVersion.CompareToIncoming(wcerr.TopologyVersion) >= 0 {
   462  			return driver.NoChange
   463  		}
   465  		// updates description to unknown
   466  		s.updateDescription(description.NewServerFromError(s.address, err, wcerr.TopologyVersion))
   467  		s.RequestImmediateCheck()
   469  		res := driver.ServerMarkedUnknown
   470  		// If the node is shutting down or is older than 4.2, we synchronously clear the pool
   471  		if wcerr.NodeIsShuttingDown() || wireVersion == nil || wireVersion.Max < wireVersion42 {
   472  			res = driver.ConnectionPoolCleared
   473  			s.pool.clear(err, serviceID)
   474  		}
   475  		return res
   476  	}
   478  	wrappedConnErr := unwrapConnectionError(err)
   479  	if wrappedConnErr == nil {
   480  		return driver.NoChange
   481  	}
   483  	// Ignore transient timeout errors.
   484  	if netErr, ok := wrappedConnErr.(net.Error); ok && netErr.Timeout() {
   485  		return driver.NoChange
   486  	}
   487  	if wrappedConnErr == context.Canceled || wrappedConnErr == context.DeadlineExceeded {
   488  		return driver.NoChange
   489  	}
   491  	// For a non-timeout network error, we clear the pool, set the description to Unknown, and cancel the in-progress
   492  	// monitoring check. The check is cancelled last to avoid a post-cancellation reconnect racing with
   493  	// updateDescription.
   494  	s.updateDescription(description.NewServerFromError(s.address, err, nil))
   495  	s.pool.clear(err, serviceID)
   496  	s.cancelCheck()
   497  	return driver.ConnectionPoolCleared
   498  }
   500  // update handles performing heartbeats and updating any subscribers of the
   501  // newest description.Server retrieved.
   502  func (s *Server) update() {
   503  	defer s.closewg.Done()
   504  	heartbeatTicker := time.NewTicker(s.cfg.heartbeatInterval)
   505  	rateLimiter := time.NewTicker(minHeartbeatInterval)
   506  	defer heartbeatTicker.Stop()
   507  	defer rateLimiter.Stop()
   508  	checkNow := s.checkNow
   509  	done := s.done
   511  	defer func() {
   512  		_ = recover()
   513  	}()
   515  	closeServer := func() {
   516  		s.subLock.Lock()
   517  		for id, c := range s.subscribers {
   518  			close(c)
   519  			delete(s.subscribers, id)
   520  		}
   521  		s.subscriptionsClosed = true
   522  		s.subLock.Unlock()
   524  		// We don't need to take s.heartbeatLock here because closeServer is called synchronously when the select checks
   525  		// below detect that the server is being closed, so we can be sure that the connection isn't being used.
   526  		if s.conn != nil {
   527  			_ = s.conn.close()
   528  		}
   529  	}
   531  	waitUntilNextCheck := func() {
   532  		// Wait until heartbeatFrequency elapses, an application operation requests an immediate check, or the server
   533  		// is disconnecting.
   534  		select {
   535  		case <-heartbeatTicker.C:
   536  		case <-checkNow:
   537  		case <-done:
   538  			// Return because the next update iteration will check the done channel again and clean up.
   539  			return
   540  		}
   542  		// Ensure we only return if minHeartbeatFrequency has elapsed or the server is disconnecting.
   543  		select {
   544  		case <-rateLimiter.C:
   545  		case <-done:
   546  			return
   547  		}
   548  	}
   550  	timeoutCnt := 0
   551  	for {
   552  		// Check if the server is disconnecting. Even if waitForNextCheck has already read from the done channel, we
   553  		// can safely read from it again because Disconnect closes the channel.
   554  		select {
   555  		case <-done:
   556  			closeServer()
   557  			return
   558  		default:
   559  		}
   561  		previousDescription := s.Description()
   563  		// Perform the next check.
   564  		desc, err := s.check()
   565  		if err == errCheckCancelled {
   566  			if atomic.LoadInt64(&s.state) != serverConnected {
   567  				continue
   568  			}
   570  			// If the server is not disconnecting, the check was cancelled by an application operation after an error.
   571  			// Wait before running the next check.
   572  			waitUntilNextCheck()
   573  			continue
   574  		}
   576  		if isShortcut := func() bool {
   577  			// Must hold the processErrorLock while updating the server description and clearing the
   578  			// pool. Not holding the lock leads to possible out-of-order processing of pool.clear() and
   579  			// pool.ready() calls from concurrent server description updates.
   580  			s.processErrorLock.Lock()
   581  			defer s.processErrorLock.Unlock()
   583  			s.updateDescription(desc)
   584  			// Retry after the first timeout before clearing the pool in case of a FAAS pause as
   585  			// described in GODRIVER-2577.
   586  			if err := unwrapConnectionError(desc.LastError); err != nil && timeoutCnt < 1 {
   587  				if err == context.Canceled || err == context.DeadlineExceeded {
   588  					timeoutCnt++
   589  					// We want to immediately retry on timeout error. Continue to next loop.
   590  					return true
   591  				}
   592  				if err, ok := err.(net.Error); ok && err.Timeout() {
   593  					timeoutCnt++
   594  					// We want to immediately retry on timeout error. Continue to next loop.
   595  					return true
   596  				}
   597  			}
   598  			if err := desc.LastError; err != nil {
   599  				// Clear the pool once the description has been updated to Unknown. Pass in a nil service ID to clear
   600  				// because the monitoring routine only runs for non-load balanced deployments in which servers don't return
   601  				// IDs.
   602  				s.pool.clear(err, nil)
   603  			}
   604  			// We're either not handling a timeout error, or we just handled the 2nd consecutive
   605  			// timeout error. In either case, reset the timeout count to 0 and return false to
   606  			// continue the normal check process.
   607  			timeoutCnt = 0
   608  			return false
   609  		}(); isShortcut {
   610  			continue
   611  		}
   613  		// If the server supports streaming or we're already streaming, we want to move to streaming the next response
   614  		// without waiting. If the server has transitioned to Unknown from a network error, we want to do another
   615  		// check without waiting in case it was a transient error and the server isn't actually down.
   616  		serverSupportsStreaming := desc.Kind != description.Unknown && desc.TopologyVersion != nil
   617  		connectionIsStreaming := s.conn != nil && s.conn.getCurrentlyStreaming()
   618  		transitionedFromNetworkError := desc.LastError != nil && unwrapConnectionError(desc.LastError) != nil &&
   619  			previousDescription.Kind != description.Unknown
   621  		if serverSupportsStreaming || connectionIsStreaming || transitionedFromNetworkError {
   622  			continue
   623  		}
   625  		// The server either does not support the streamable protocol or is not in a healthy state, so we wait until
   626  		// the next check.
   627  		waitUntilNextCheck()
   628  	}
   629  }
   631  // updateDescription handles updating the description on the Server, notifying
   632  // subscribers, and potentially draining the connection pool. The initial
   633  // parameter is used to determine if this is the first description from the
   634  // server.
   635  func (s *Server) updateDescription(desc description.Server) {
   636  	if s.cfg.loadBalanced {
   637  		// In load balanced mode, there are no updates from the monitoring routine. For errors encountered in pooled
   638  		// connections, the server should not be marked Unknown to ensure that the LB remains selectable.
   639  		return
   640  	}
   642  	defer func() {
   643  		//  ¯\_(ツ)_/¯
   644  		_ = recover()
   645  	}()
   647  	// Anytime we update the server description to something other than "unknown", set the pool to
   648  	// "ready". Do this before updating the description so that connections can be checked out as
   649  	// soon as the server is selectable. If the pool is already ready, this operation is a no-op.
   650  	// Note that this behavior is roughly consistent with the current Go driver behavior (connects
   651  	// to all servers, even non-data-bearing nodes) but deviates slightly from CMAP spec, which
   652  	// specifies a more restricted set of server descriptions and topologies that should mark the
   653  	// pool ready. We don't have access to the topology here, so prefer the current Go driver
   654  	// behavior for simplicity.
   655  	if desc.Kind != description.Unknown {
   656  		_ = s.pool.ready()
   657  	}
   659  	// Use the updateTopologyCallback to update the parent Topology and get the description that should be stored.
   660  	callback, ok := s.updateTopologyCallback.Load().(updateTopologyCallback)
   661  	if ok && callback != nil {
   662  		desc = callback(desc)
   663  	}
   664  	s.desc.Store(desc)
   666  	s.subLock.Lock()
   667  	for _, c := range s.subscribers {
   668  		select {
   669  		// drain the channel if it isn't empty
   670  		case <-c:
   671  		default:
   672  		}
   673  		c <- desc
   674  	}
   675  	s.subLock.Unlock()
   676  }
   678  // createConnection creates a new connection instance but does not call connect on it. The caller must call connect
   679  // before the connection can be used for network operations.
   680  func (s *Server) createConnection() *connection {
   681  	opts := copyConnectionOpts(s.cfg.connectionOpts)
   682  	opts = append(opts,
   683  		WithConnectTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
   684  		WithReadTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
   685  		WithWriteTimeout(func(time.Duration) time.Duration { return s.cfg.heartbeatTimeout }),
   686  		// We override whatever handshaker is currently attached to the options with a basic
   687  		// one because need to make sure we don't do auth.
   688  		WithHandshaker(func(h Handshaker) Handshaker {
   689  			return operation.NewHello().AppName(s.cfg.appname).Compressors(s.cfg.compressionOpts).
   690  				ServerAPI(s.cfg.serverAPI)
   691  		}),
   692  		// Override any monitors specified in options with nil to avoid monitoring heartbeats.
   693  		WithMonitor(func(*event.CommandMonitor) *event.CommandMonitor { return nil }),
   694  	)
   696  	return newConnection(s.address, opts...)
   697  }
   699  func copyConnectionOpts(opts []ConnectionOption) []ConnectionOption {
   700  	optsCopy := make([]ConnectionOption, len(opts))
   701  	copy(optsCopy, opts)
   702  	return optsCopy
   703  }
   705  func (s *Server) setupHeartbeatConnection() error {
   706  	conn := s.createConnection()
   708  	// Take the lock when assigning the context and connection because they're accessed by cancelCheck.
   709  	s.heartbeatLock.Lock()
   710  	if s.heartbeatCtxCancel != nil {
   711  		// Ensure the previous context is cancelled to avoid a leak.
   712  		s.heartbeatCtxCancel()
   713  	}
   714  	s.heartbeatCtx, s.heartbeatCtxCancel = context.WithCancel(s.globalCtx)
   715  	s.conn = conn
   716  	s.heartbeatLock.Unlock()
   718  	return s.conn.connect(s.heartbeatCtx)
   719  }
   721  // cancelCheck cancels in-progress connection dials and reads. It does not set any fields on the server.
   722  func (s *Server) cancelCheck() {
   723  	var conn *connection
   725  	// Take heartbeatLock for mutual exclusion with the checks in the update function.
   726  	s.heartbeatLock.Lock()
   727  	if s.heartbeatCtx != nil {
   728  		s.heartbeatCtxCancel()
   729  	}
   730  	conn = s.conn
   731  	s.heartbeatLock.Unlock()
   733  	if conn == nil {
   734  		return
   735  	}
   737  	// If the connection exists, we need to wait for it to be connected because conn.connect() and
   738  	// conn.close() cannot be called concurrently. If the connection wasn't successfully opened, its
   739  	// state was set back to disconnected, so calling conn.close() will be a no-op.
   740  	conn.closeConnectContext()
   741  	conn.wait()
   742  	_ = conn.close()
   743  }
   745  func (s *Server) checkWasCancelled() bool {
   746  	return s.heartbeatCtx.Err() != nil
   747  }
   749  func (s *Server) createBaseOperation(conn driver.Connection) *operation.Hello {
   750  	return operation.
   751  		NewHello().
   752  		ClusterClock(s.cfg.clock).
   753  		Deployment(driver.SingleConnectionDeployment{conn}).
   754  		ServerAPI(s.cfg.serverAPI)
   755  }
   757  func (s *Server) check() (description.Server, error) {
   758  	var descPtr *description.Server
   759  	var err error
   760  	var duration time.Duration
   762  	start := time.Now()
   763  	if s.conn == nil || s.conn.closed() || s.checkWasCancelled() {
   764  		// Create a new connection if this is the first check, the connection was closed after an error during the previous
   765  		// check, or the previous check was cancelled.
   766  		if s.conn != nil {
   767  			s.publishServerHeartbeatStartedEvent(s.conn.ID(), false)
   768  		}
   769  		// Create a new connection and add it's handshake RTT as a sample.
   770  		err = s.setupHeartbeatConnection()
   771  		duration = time.Since(start)
   772  		if err == nil {
   773  			// Use the description from the connection handshake as the value for this check.
   774  			s.rttMonitor.addSample(s.conn.helloRTT)
   775  			descPtr = &s.conn.desc
   776  			if s.conn != nil {
   777  				s.publishServerHeartbeatSucceededEvent(s.conn.ID(), duration, s.conn.desc, false)
   778  			}
   779  		} else {
   780  			err = unwrapConnectionError(err)
   781  			if s.conn != nil {
   782  				s.publishServerHeartbeatFailedEvent(s.conn.ID(), duration, err, false)
   783  			}
   784  		}
   785  	} else {
   786  		// An existing connection is being used. Use the server description properties to execute the right heartbeat.
   788  		// Wrap conn in a type that implements driver.StreamerConnection.
   789  		heartbeatConn := initConnection{s.conn}
   790  		baseOperation := s.createBaseOperation(heartbeatConn)
   791  		previousDescription := s.Description()
   792  		streamable := previousDescription.TopologyVersion != nil
   794  		s.publishServerHeartbeatStartedEvent(s.conn.ID(), s.conn.getCurrentlyStreaming() || streamable)
   795  		switch {
   796  		case s.conn.getCurrentlyStreaming():
   797  			// The connection is already in a streaming state, so we stream the next response.
   798  			err = baseOperation.StreamResponse(s.heartbeatCtx, heartbeatConn)
   799  		case streamable:
   800  			// The server supports the streamable protocol. Set the socket timeout to
   801  			// connectTimeoutMS+heartbeatFrequencyMS and execute an awaitable hello request. Set conn.canStream so
   802  			// the wire message will advertise streaming support to the server.
   804  			// Calculation for maxAwaitTimeMS is taken from time.Duration.Milliseconds (added in Go 1.13).
   805  			maxAwaitTimeMS := int64(s.cfg.heartbeatInterval) / 1e6
   806  			// If connectTimeoutMS=0, the socket timeout should be infinite. Otherwise, it is connectTimeoutMS +
   807  			// heartbeatFrequencyMS to account for the fact that the query will block for heartbeatFrequencyMS
   808  			// server-side.
   809  			socketTimeout := s.cfg.heartbeatTimeout
   810  			if socketTimeout != 0 {
   811  				socketTimeout += s.cfg.heartbeatInterval
   812  			}
   813  			s.conn.setSocketTimeout(socketTimeout)
   814  			baseOperation = baseOperation.TopologyVersion(previousDescription.TopologyVersion).
   815  				MaxAwaitTimeMS(maxAwaitTimeMS)
   816  			s.conn.setCanStream(true)
   817  			err = baseOperation.Execute(s.heartbeatCtx)
   818  		default:
   819  			// The server doesn't support the awaitable protocol. Set the socket timeout to connectTimeoutMS and
   820  			// execute a regular heartbeat without any additional parameters.
   822  			s.conn.setSocketTimeout(s.cfg.heartbeatTimeout)
   823  			err = baseOperation.Execute(s.heartbeatCtx)
   824  		}
   825  		duration = time.Since(start)
   827  		if err == nil {
   828  			tempDesc := baseOperation.Result(s.address)
   829  			descPtr = &tempDesc
   830  			s.publishServerHeartbeatSucceededEvent(s.conn.ID(), duration, tempDesc, s.conn.getCurrentlyStreaming() || streamable)
   831  		} else {
   832  			// Close the connection here rather than below so we ensure we're not closing a connection that wasn't
   833  			// successfully created.
   834  			if s.conn != nil {
   835  				_ = s.conn.close()
   836  			}
   837  			s.publishServerHeartbeatFailedEvent(s.conn.ID(), duration, err, s.conn.getCurrentlyStreaming() || streamable)
   838  		}
   839  	}
   841  	if descPtr != nil {
   842  		// The check was successful. Set the average RTT and the 90th percentile RTT and return.
   843  		desc := *descPtr
   844  		desc = desc.SetAverageRTT(s.rttMonitor.EWMA())
   845  		desc.HeartbeatInterval = s.cfg.heartbeatInterval
   846  		return desc, nil
   847  	}
   849  	if s.checkWasCancelled() {
   850  		// If the previous check was cancelled, we don't want to clear the pool. Return a sentinel error so the caller
   851  		// will know that an actual error didn't occur.
   852  		return emptyDescription, errCheckCancelled
   853  	}
   855  	// An error occurred. We reset the RTT monitor for all errors and return an Unknown description. The pool must also
   856  	// be cleared, but only after the description has already been updated, so that is handled by the caller.
   857  	topologyVersion := extractTopologyVersion(err)
   858  	s.rttMonitor.reset()
   859  	return description.NewServerFromError(s.address, err, topologyVersion), nil
   860  }
   862  func extractTopologyVersion(err error) *description.TopologyVersion {
   863  	if ce, ok := err.(ConnectionError); ok {
   864  		err = ce.Wrapped
   865  	}
   867  	switch converted := err.(type) {
   868  	case driver.Error:
   869  		return converted.TopologyVersion
   870  	case driver.WriteCommandError:
   871  		if converted.WriteConcernError != nil {
   872  			return converted.WriteConcernError.TopologyVersion
   873  		}
   874  	}
   876  	return nil
   877  }
   879  // RTTMonitor returns this server's round-trip-time monitor.
   880  func (s *Server) RTTMonitor() driver.RTTMonitor {
   881  	return s.rttMonitor
   882  }
   884  // OperationCount returns the current number of in-progress operations for this server.
   885  func (s *Server) OperationCount() int64 {
   886  	return atomic.LoadInt64(&s.operationCount)
   887  }
   889  // String implements the Stringer interface.
   890  func (s *Server) String() string {
   891  	desc := s.Description()
   892  	state := atomic.LoadInt64(&s.state)
   893  	str := fmt.Sprintf("Addr: %s, Type: %s, State: %s",
   894  		s.address, desc.Kind, serverStateString(state))
   895  	if len(desc.Tags) != 0 {
   896  		str += fmt.Sprintf(", Tag sets: %s", desc.Tags)
   897  	}
   898  	if state == serverConnected {
   899  		str += fmt.Sprintf(", Average RTT: %s, Min RTT: %s", desc.AverageRTT, s.RTTMonitor().Min())
   900  	}
   901  	if desc.LastError != nil {
   902  		str += fmt.Sprintf(", Last error: %s", desc.LastError)
   903  	}
   905  	return str
   906  }
// ServerSubscription represents a subscription to the description.Server updates for
// a specific server. Updates are read from C; call Unsubscribe to stop receiving them.
type ServerSubscription struct {
	C  <-chan description.Server // receive-only channel on which updates are delivered
	s  *Server                   // the server this subscription observes
	id uint64                    // identifies this subscription within s.subscribers
}
   916  // Unsubscribe unsubscribes this ServerSubscription from updates and closes the
   917  // subscription channel.
   918  func (ss *ServerSubscription) Unsubscribe() error {
   919  	ss.s.subLock.Lock()
   920  	defer ss.s.subLock.Unlock()
   921  	if ss.s.subscriptionsClosed {
   922  		return nil
   923  	}
   925  	ch, ok := ss.s.subscribers[]
   926  	if !ok {
   927  		return nil
   928  	}
   930  	close(ch)
   931  	delete(ss.s.subscribers,
   933  	return nil
   934  }
   936  // publishes a ServerOpeningEvent to indicate the server is being initialized
   937  func (s *Server) publishServerOpeningEvent(addr address.Address) {
   938  	if s == nil {
   939  		return
   940  	}
   942  	serverOpening := &event.ServerOpeningEvent{
   943  		Address:    addr,
   944  		TopologyID: s.topologyID,
   945  	}
   947  	if s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerOpening != nil {
   948  		s.cfg.serverMonitor.ServerOpening(serverOpening)
   949  	}
   950  }
   952  // publishes a ServerHeartbeatStartedEvent to indicate a hello command has started
   953  func (s *Server) publishServerHeartbeatStartedEvent(connectionID string, await bool) {
   954  	serverHeartbeatStarted := &event.ServerHeartbeatStartedEvent{
   955  		ConnectionID: connectionID,
   956  		Awaited:      await,
   957  	}
   959  	if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatStarted != nil {
   960  		s.cfg.serverMonitor.ServerHeartbeatStarted(serverHeartbeatStarted)
   961  	}
   962  }
   964  // publishes a ServerHeartbeatSucceededEvent to indicate hello has succeeded
   965  func (s *Server) publishServerHeartbeatSucceededEvent(connectionID string,
   966  	duration time.Duration,
   967  	desc description.Server,
   968  	await bool,
   969  ) {
   970  	serverHeartbeatSucceeded := &event.ServerHeartbeatSucceededEvent{
   971  		DurationNanos: duration.Nanoseconds(),
   972  		Duration:      duration,
   973  		Reply:         desc,
   974  		ConnectionID:  connectionID,
   975  		Awaited:       await,
   976  	}
   978  	if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatSucceeded != nil {
   979  		s.cfg.serverMonitor.ServerHeartbeatSucceeded(serverHeartbeatSucceeded)
   980  	}
   981  }
   983  // publishes a ServerHeartbeatFailedEvent to indicate hello has failed
   984  func (s *Server) publishServerHeartbeatFailedEvent(connectionID string,
   985  	duration time.Duration,
   986  	err error,
   987  	await bool,
   988  ) {
   989  	serverHeartbeatFailed := &event.ServerHeartbeatFailedEvent{
   990  		DurationNanos: duration.Nanoseconds(),
   991  		Duration:      duration,
   992  		Failure:       err,
   993  		ConnectionID:  connectionID,
   994  		Awaited:       await,
   995  	}
   997  	if s != nil && s.cfg.serverMonitor != nil && s.cfg.serverMonitor.ServerHeartbeatFailed != nil {
   998  		s.cfg.serverMonitor.ServerHeartbeatFailed(serverHeartbeatFailed)
   999  	}
  1000  }
  1002  // unwrapConnectionError returns the connection error wrapped by err, or nil if err does not wrap a connection error.
  1003  func unwrapConnectionError(err error) error {
  1004  	// This is essentially an implementation of errors.As to unwrap this error until we get a ConnectionError and then
  1005  	// return ConnectionError.Wrapped.
  1007  	connErr, ok := err.(ConnectionError)
  1008  	if ok {
  1009  		return connErr.Wrapped
  1010  	}
  1012  	driverErr, ok := err.(driver.Error)
  1013  	if !ok || !driverErr.NetworkError() {
  1014  		return nil
  1015  	}
  1017  	connErr, ok = driverErr.Wrapped.(ConnectionError)
  1018  	if ok {
  1019  		return connErr.Wrapped
  1020  	}
  1022  	return nil
  1023  }