github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/mongo/open.go (about) 1 // Copyright 2014 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package mongo 5 6 import ( 7 "crypto/tls" 8 "crypto/x509" 9 stderrors "errors" 10 "fmt" 11 "net" 12 "strings" 13 "time" 14 15 "github.com/juju/errors" 16 "github.com/juju/http/v2" 17 "github.com/juju/mgo/v3" 18 "github.com/juju/names/v5" 19 "github.com/juju/utils/v3/cert" 20 ) 21 22 // SocketTimeout should be long enough that even a slow mongo server 23 // will respond in that length of time, and must also be long enough 24 // to allow for completion of heavyweight queries. 25 // 26 // Note: 1 minute is mgo's default socket timeout value. 27 // 28 // Also note: We have observed mongodb occasionally getting "stuck" 29 // for over 30s in the field. 30 const SocketTimeout = time.Minute 31 32 // defaultDialTimeout should be representative of the upper bound of 33 // time taken to dial a mongo server from within the same 34 // cloud/private network. 35 const defaultDialTimeout = 30 * time.Second 36 37 // DialOpts holds configuration parameters that control the 38 // Dialing behavior when connecting to a controller. 39 type DialOpts struct { 40 // Timeout is the amount of time to wait contacting 41 // a controller. 42 Timeout time.Duration 43 44 // SocketTimeout is the amount of time to wait for a 45 // non-responding socket to the database before it is forcefully 46 // closed. If this is zero, the value of the SocketTimeout const 47 // will be used. 48 SocketTimeout time.Duration 49 50 // Direct informs whether to establish connections only with the 51 // specified seed servers, or to obtain information for the whole 52 // cluster and establish connections with further servers too. 53 Direct bool 54 55 // PostDial, if non-nil, is called by DialWithInfo with the 56 // mgo.Session after a successful dial but before DialWithInfo 57 // returns to its caller. 58 PostDial func(*mgo.Session) error 59 60 // PostDialServer, if non-nil, is called by DialWithInfo after 61 // dialing a MongoDB server connection, successfully or not. 62 // The address dialed and amount of time taken are included, 63 // as well as the error if any. 64 PostDialServer func(addr string, _ time.Duration, _ error) 65 66 // PoolLimit defines the per-server socket pool limit 67 PoolLimit int 68 } 69 70 // DefaultDialOpts returns a DialOpts representing the default 71 // parameters for contacting a controller. 72 // 73 // NOTE(axw) these options are inappropriate for tests in CI, 74 // as CI tends to run on machines with slow I/O (or thrashed 75 // I/O with limited IOPs). For tests, use mongotest.DialOpts(). 76 func DefaultDialOpts() DialOpts { 77 return DialOpts{ 78 Timeout: defaultDialTimeout, 79 SocketTimeout: SocketTimeout, 80 } 81 } 82 83 // Info encapsulates information about cluster of 84 // mongo servers and can be used to make a 85 // connection to that cluster. 86 type Info struct { 87 // Addrs gives the addresses of the MongoDB servers for the state. 88 // Each address should be in the form address:port. 89 Addrs []string 90 91 // CACert holds the CA certificate that will be used 92 // to validate the controller's certificate, in PEM format. 93 CACert string 94 95 // DisableTLS controls whether the connection to MongoDB servers 96 // is made using TLS (the default), or not. 97 DisableTLS bool 98 } 99 100 // MongoInfo encapsulates information about cluster of 101 // servers holding juju state and can be used to make a 102 // connection to that cluster. 103 type MongoInfo struct { 104 // mongo.Info contains the addresses and cert of the mongo cluster. 105 Info 106 107 // Tag holds the name of the entity that is connecting. 108 // It should be nil when connecting as an administrator. 109 Tag names.Tag 110 111 // Password holds the password for the connecting entity. 112 Password string 113 } 114 115 // DialInfo returns information on how to dial 116 // the state's mongo server with the given info 117 // and dial options. 118 func DialInfo(info Info, opts DialOpts) (*mgo.DialInfo, error) { 119 if len(info.Addrs) == 0 { 120 return nil, stderrors.New("no mongo addresses") 121 } 122 123 var tlsConfig *tls.Config 124 if !info.DisableTLS { 125 if len(info.CACert) == 0 { 126 return nil, stderrors.New("missing CA certificate") 127 } 128 xcert, err := cert.ParseCert(info.CACert) 129 if err != nil { 130 return nil, fmt.Errorf("cannot parse CA certificate: %v", err) 131 } 132 pool := x509.NewCertPool() 133 pool.AddCert(xcert) 134 135 tlsConfig = http.SecureTLSConfig() 136 tlsConfig.RootCAs = pool 137 tlsConfig.ServerName = "juju-mongodb" 138 139 // TODO(natefinch): revisit this when are full-time on mongo 3. 140 // We have to add non-ECDHE suites because mongo doesn't support ECDHE. 141 moreSuites := []uint16{ 142 tls.TLS_RSA_WITH_AES_128_GCM_SHA256, 143 tls.TLS_RSA_WITH_AES_256_GCM_SHA384, 144 } 145 tlsConfig.CipherSuites = append(tlsConfig.CipherSuites, moreSuites...) 146 } 147 148 dial := func(server *mgo.ServerAddr) (_ net.Conn, err error) { 149 if opts.PostDialServer != nil { 150 before := time.Now() 151 defer func() { 152 taken := time.Now().Sub(before) 153 opts.PostDialServer(server.String(), taken, err) 154 }() 155 } 156 157 addr := server.TCPAddr().String() 158 c, err := net.DialTimeout("tcp", addr, opts.Timeout) 159 if err != nil { 160 logger.Debugf("mongodb connection failed, will retry: %v", err) 161 return nil, err 162 } 163 if tlsConfig != nil { 164 cc := tls.Client(c, tlsConfig) 165 if err := cc.Handshake(); err != nil { 166 logger.Warningf("TLS handshake failed: %v", err) 167 if err := c.Close(); err != nil { 168 logger.Warningf("failed to close connection: %v", err) 169 } 170 return nil, err 171 } 172 c = cc 173 } 174 logger.Debugf("dialed mongodb server at %q", addr) 175 return c, nil 176 } 177 178 return &mgo.DialInfo{ 179 Addrs: info.Addrs, 180 Timeout: opts.Timeout, 181 DialServer: dial, 182 Direct: opts.Direct, 183 PoolLimit: opts.PoolLimit, 184 }, nil 185 } 186 187 // DialWithInfo establishes a new session to the cluster identified by info, 188 // with the specified options. If either Tag or Password are specified, then 189 // a Login call on the admin database will be made. 190 func DialWithInfo(info MongoInfo, opts DialOpts) (*mgo.Session, error) { 191 if opts.Timeout == 0 { 192 return nil, errors.New("a non-zero Timeout must be specified") 193 } 194 195 dialInfo, err := DialInfo(info.Info, opts) 196 if err != nil { 197 return nil, err 198 } 199 200 session, err := mgo.DialWithInfo(dialInfo) 201 if err != nil { 202 return nil, err 203 } 204 205 if opts.SocketTimeout == 0 { 206 opts.SocketTimeout = SocketTimeout 207 } 208 session.SetSocketTimeout(opts.SocketTimeout) 209 210 if opts.PostDial != nil { 211 if err := opts.PostDial(session); err != nil { 212 session.Close() 213 return nil, errors.Annotate(err, "PostDial failed") 214 } 215 } 216 if info.Tag != nil || info.Password != "" { 217 user := AdminUser 218 if info.Tag != nil { 219 user = info.Tag.String() 220 } 221 if err := Login(session, user, info.Password); err != nil { 222 session.Close() 223 return nil, errors.Trace(err) 224 } 225 } 226 return session, nil 227 } 228 229 // Login logs in to the mongodb admin database. 230 func Login(session *mgo.Session, user, password string) error { 231 admin := session.DB("admin") 232 if err := admin.Login(user, password); err != nil { 233 return MaybeUnauthorizedf(err, "cannot log in to admin database as %q", user) 234 } 235 return nil 236 } 237 238 // MaybeUnauthorizedf checks if the cause of the given error is a Mongo 239 // authorization error, and if so, wraps the error with errors.Unauthorizedf. 240 func MaybeUnauthorizedf(err error, message string, args ...interface{}) error { 241 if isUnauthorized(errors.Cause(err)) { 242 err = errors.Unauthorizedf("unauthorized mongo access: %s", err) 243 } 244 return errors.Annotatef(err, message, args...) 245 } 246 247 func isUnauthorized(err error) bool { 248 if err == nil { 249 return false 250 } 251 // Some unauthorized access errors have no error code, 252 // just a simple error string; and some do have error codes 253 // but are not of consistent types (LastError/QueryError). 254 for _, prefix := range []string{ 255 "auth fail", 256 "not authorized", 257 "server returned error on SASL authentication step: Authentication failed.", 258 } { 259 if strings.HasPrefix(err.Error(), prefix) { 260 return true 261 } 262 } 263 if err, ok := err.(*mgo.QueryError); ok { 264 return err.Code == 10057 || 265 err.Code == 13 || 266 err.Message == "need to login" || 267 err.Message == "unauthorized" 268 } 269 return false 270 }