github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/mongo/open.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package mongo
     5  
     6  import (
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	stderrors "errors"
    10  	"fmt"
    11  	"net"
    12  	"strings"
    13  	"time"
    14  
    15  	"github.com/juju/errors"
    16  	"github.com/juju/http/v2"
    17  	"github.com/juju/mgo/v3"
    18  	"github.com/juju/names/v5"
    19  	"github.com/juju/utils/v3/cert"
    20  )
    21  
    22  // SocketTimeout should be long enough that even a slow mongo server
    23  // will respond in that length of time, and must also be long enough
    24  // to allow for completion of heavyweight queries.
    25  //
    26  // Note: 1 minute is mgo's default socket timeout value.
    27  //
    28  // Also note: We have observed mongodb occasionally getting "stuck"
    29  // for over 30s in the field.
    30  const SocketTimeout = time.Minute
    31  
    32  // defaultDialTimeout should be representative of the upper bound of
    33  // time taken to dial a mongo server from within the same
    34  // cloud/private network.
    35  const defaultDialTimeout = 30 * time.Second
    36  
// DialOpts holds configuration parameters that control the
// dialing behavior when connecting to a controller.
type DialOpts struct {
	// Timeout is the amount of time to wait contacting
	// a controller. DialInfo uses it as the per-server TCP dial
	// timeout and as the mgo DialInfo Timeout; DialWithInfo
	// rejects a zero Timeout.
	Timeout time.Duration

	// SocketTimeout is the amount of time to wait for a
	// non-responding socket to the database before it is forcefully
	// closed. If this is zero, the value of the SocketTimeout const
	// will be used.
	SocketTimeout time.Duration

	// Direct informs whether to establish connections only with the
	// specified seed servers, or to obtain information for the whole
	// cluster and establish connections with further servers too.
	// It is passed through unchanged as the mgo DialInfo Direct field.
	Direct bool

	// PostDial, if non-nil, is called by DialWithInfo with the
	// mgo.Session after a successful dial but before DialWithInfo
	// returns to its caller. It runs before any Login. If it returns
	// an error, the session is closed and DialWithInfo fails.
	PostDial func(*mgo.Session) error

	// PostDialServer, if non-nil, is called by DialWithInfo after
	// dialing a MongoDB server connection, successfully or not.
	// The address dialed and amount of time taken are included,
	// as well as the error if any.
	PostDialServer func(addr string, _ time.Duration, _ error)

	// PoolLimit defines the per-server socket pool limit. It is passed
	// through unchanged as the mgo DialInfo PoolLimit field.
	PoolLimit int
}
    69  
    70  // DefaultDialOpts returns a DialOpts representing the default
    71  // parameters for contacting a controller.
    72  //
    73  // NOTE(axw) these options are inappropriate for tests in CI,
    74  // as CI tends to run on machines with slow I/O (or thrashed
    75  // I/O with limited IOPs). For tests, use mongotest.DialOpts().
    76  func DefaultDialOpts() DialOpts {
    77  	return DialOpts{
    78  		Timeout:       defaultDialTimeout,
    79  		SocketTimeout: SocketTimeout,
    80  	}
    81  }
    82  
// Info encapsulates information about a cluster of
// mongo servers and can be used to make a
// connection to that cluster.
type Info struct {
	// Addrs gives the addresses of the MongoDB servers for the state.
	// Each address should be in the form address:port.
	// DialInfo returns an error if it is empty.
	Addrs []string

	// CACert holds the CA certificate that will be used
	// to validate the controller's certificate, in PEM format.
	// It must be non-empty unless DisableTLS is set.
	CACert string

	// DisableTLS controls whether the connection to MongoDB servers
	// is made using TLS (the default), or not.
	DisableTLS bool
}
    99  
// MongoInfo encapsulates information about a cluster of
// servers holding juju state and can be used to make a
// connection to that cluster.
type MongoInfo struct {
	// Info contains the addresses and cert of the mongo cluster.
	Info

	// Tag holds the name of the entity that is connecting.
	// It should be nil when connecting as an administrator.
	// When non-nil, DialWithInfo logs in with Tag.String() as the
	// user name; when nil but Password is set, it logs in as AdminUser.
	Tag names.Tag

	// Password holds the password for the connecting entity.
	// If both Tag and Password are empty, DialWithInfo does not log in.
	Password string
}
   114  
   115  // DialInfo returns information on how to dial
   116  // the state's mongo server with the given info
   117  // and dial options.
   118  func DialInfo(info Info, opts DialOpts) (*mgo.DialInfo, error) {
   119  	if len(info.Addrs) == 0 {
   120  		return nil, stderrors.New("no mongo addresses")
   121  	}
   122  
   123  	var tlsConfig *tls.Config
   124  	if !info.DisableTLS {
   125  		if len(info.CACert) == 0 {
   126  			return nil, stderrors.New("missing CA certificate")
   127  		}
   128  		xcert, err := cert.ParseCert(info.CACert)
   129  		if err != nil {
   130  			return nil, fmt.Errorf("cannot parse CA certificate: %v", err)
   131  		}
   132  		pool := x509.NewCertPool()
   133  		pool.AddCert(xcert)
   134  
   135  		tlsConfig = http.SecureTLSConfig()
   136  		tlsConfig.RootCAs = pool
   137  		tlsConfig.ServerName = "juju-mongodb"
   138  
   139  		// TODO(natefinch): revisit this when are full-time on mongo 3.
   140  		// We have to add non-ECDHE suites because mongo doesn't support ECDHE.
   141  		moreSuites := []uint16{
   142  			tls.TLS_RSA_WITH_AES_128_GCM_SHA256,
   143  			tls.TLS_RSA_WITH_AES_256_GCM_SHA384,
   144  		}
   145  		tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, moreSuites...)
   146  	}
   147  
   148  	dial := func(server *mgo.ServerAddr) (_ net.Conn, err error) {
   149  		if opts.PostDialServer != nil {
   150  			before := time.Now()
   151  			defer func() {
   152  				taken := time.Now().Sub(before)
   153  				opts.PostDialServer(server.String(), taken, err)
   154  			}()
   155  		}
   156  
   157  		addr := server.TCPAddr().String()
   158  		c, err := net.DialTimeout("tcp", addr, opts.Timeout)
   159  		if err != nil {
   160  			logger.Debugf("mongodb connection failed, will retry: %v", err)
   161  			return nil, err
   162  		}
   163  		if tlsConfig != nil {
   164  			cc := tls.Client(c, tlsConfig)
   165  			if err := cc.Handshake(); err != nil {
   166  				logger.Warningf("TLS handshake failed: %v", err)
   167  				if err := c.Close(); err != nil {
   168  					logger.Warningf("failed to close connection: %v", err)
   169  				}
   170  				return nil, err
   171  			}
   172  			c = cc
   173  		}
   174  		logger.Debugf("dialed mongodb server at %q", addr)
   175  		return c, nil
   176  	}
   177  
   178  	return &mgo.DialInfo{
   179  		Addrs:      info.Addrs,
   180  		Timeout:    opts.Timeout,
   181  		DialServer: dial,
   182  		Direct:     opts.Direct,
   183  		PoolLimit:  opts.PoolLimit,
   184  	}, nil
   185  }
   186  
   187  // DialWithInfo establishes a new session to the cluster identified by info,
   188  // with the specified options. If either Tag or Password are specified, then
   189  // a Login call on the admin database will be made.
   190  func DialWithInfo(info MongoInfo, opts DialOpts) (*mgo.Session, error) {
   191  	if opts.Timeout == 0 {
   192  		return nil, errors.New("a non-zero Timeout must be specified")
   193  	}
   194  
   195  	dialInfo, err := DialInfo(info.Info, opts)
   196  	if err != nil {
   197  		return nil, err
   198  	}
   199  
   200  	session, err := mgo.DialWithInfo(dialInfo)
   201  	if err != nil {
   202  		return nil, err
   203  	}
   204  
   205  	if opts.SocketTimeout == 0 {
   206  		opts.SocketTimeout = SocketTimeout
   207  	}
   208  	session.SetSocketTimeout(opts.SocketTimeout)
   209  
   210  	if opts.PostDial != nil {
   211  		if err := opts.PostDial(session); err != nil {
   212  			session.Close()
   213  			return nil, errors.Annotate(err, "PostDial failed")
   214  		}
   215  	}
   216  	if info.Tag != nil || info.Password != "" {
   217  		user := AdminUser
   218  		if info.Tag != nil {
   219  			user = info.Tag.String()
   220  		}
   221  		if err := Login(session, user, info.Password); err != nil {
   222  			session.Close()
   223  			return nil, errors.Trace(err)
   224  		}
   225  	}
   226  	return session, nil
   227  }
   228  
   229  // Login logs in to the mongodb admin database.
   230  func Login(session *mgo.Session, user, password string) error {
   231  	admin := session.DB("admin")
   232  	if err := admin.Login(user, password); err != nil {
   233  		return MaybeUnauthorizedf(err, "cannot log in to admin database as %q", user)
   234  	}
   235  	return nil
   236  }
   237  
   238  // MaybeUnauthorizedf checks if the cause of the given error is a Mongo
   239  // authorization error, and if so, wraps the error with errors.Unauthorizedf.
   240  func MaybeUnauthorizedf(err error, message string, args ...interface{}) error {
   241  	if isUnauthorized(errors.Cause(err)) {
   242  		err = errors.Unauthorizedf("unauthorized mongo access: %s", err)
   243  	}
   244  	return errors.Annotatef(err, message, args...)
   245  }
   246  
   247  func isUnauthorized(err error) bool {
   248  	if err == nil {
   249  		return false
   250  	}
   251  	// Some unauthorized access errors have no error code,
   252  	// just a simple error string; and some do have error codes
   253  	// but are not of consistent types (LastError/QueryError).
   254  	for _, prefix := range []string{
   255  		"auth fail",
   256  		"not authorized",
   257  		"server returned error on SASL authentication step: Authentication failed.",
   258  	} {
   259  		if strings.HasPrefix(err.Error(), prefix) {
   260  			return true
   261  		}
   262  	}
   263  	if err, ok := err.(*mgo.QueryError); ok {
   264  		return err.Code == 10057 ||
   265  			err.Code == 13 ||
   266  			err.Message == "need to login" ||
   267  			err.Message == "unauthorized"
   268  	}
   269  	return false
   270  }