launchpad.net/~rogpeppe/juju-core/500-errgo-fix@v0.0.0-20140213181702-000000002356/testing/mgo.go (about) 1 // Copyright 2012, 2013 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package testing 5 6 import ( 7 "bufio" 8 "crypto/tls" 9 "crypto/x509" 10 "fmt" 11 "io" 12 "io/ioutil" 13 "net" 14 "os" 15 "os/exec" 16 "path/filepath" 17 "strconv" 18 "strings" 19 stdtesting "testing" 20 "time" 21 22 "labix.org/v2/mgo" 23 gc "launchpad.net/gocheck" 24 25 "launchpad.net/juju-core/cert" 26 "launchpad.net/juju-core/log" 27 "launchpad.net/juju-core/utils" 28 ) 29 30 var ( 31 // MgoServer is a shared mongo server used by tests. 32 MgoServer = &MgoInstance{ssl: true} 33 ) 34 35 type MgoInstance struct { 36 // addr holds the address of the MongoDB server 37 addr string 38 39 // MgoPort holds the port of the MongoDB server. 40 port int 41 42 // server holds the running MongoDB command. 43 server *exec.Cmd 44 45 // exited receives a value when the mongodb server exits. 46 exited <-chan struct{} 47 48 // dir holds the directory that MongoDB is running in. 49 dir string 50 51 // ssl determines whether the MongoDB server will use TLS 52 ssl bool 53 54 // Params is a list of additional parameters that will be passed to 55 // the mongod application 56 Params []string 57 } 58 59 // Addr returns the address of the MongoDB server. 60 func (m *MgoInstance) Addr() string { 61 return m.addr 62 } 63 64 // Port returns the port of the MongoDB server. 65 func (m *MgoInstance) Port() int { 66 return m.port 67 } 68 69 // We specify a timeout to mgo.Dial, to prevent 70 // mongod failures hanging the tests. 71 const mgoDialTimeout = 15 * time.Second 72 73 // MgoSuite is a suite that deletes all content from the shared MongoDB 74 // server at the end of every test and supplies a connection to the shared 75 // MongoDB server. 76 type MgoSuite struct { 77 Session *mgo.Session 78 } 79 80 // Start starts a MongoDB server in a temporary directory. 81 func (inst *MgoInstance) Start(ssl bool) error { 82 dbdir, err := ioutil.TempDir("", "test-mgo") 83 if err != nil { 84 return err 85 } 86 87 // give them all the same keyfile so they can talk appropriately 88 keyFilePath := filepath.Join(dbdir, "keyfile") 89 err = ioutil.WriteFile(keyFilePath, []byte("not very secret"), 0600) 90 if err != nil { 91 return fmt.Errorf("cannot write key file: %v", err) 92 } 93 94 pemPath := filepath.Join(dbdir, "server.pem") 95 err = ioutil.WriteFile(pemPath, []byte(ServerCert+ServerKey), 0600) 96 if err != nil { 97 return fmt.Errorf("cannot write cert/key PEM: %v", err) 98 } 99 inst.port = FindTCPPort() 100 inst.addr = fmt.Sprintf("localhost:%d", inst.port) 101 inst.dir = dbdir 102 inst.ssl = ssl 103 if err := inst.run(); err != nil { 104 inst.addr = "" 105 inst.port = 0 106 os.RemoveAll(inst.dir) 107 inst.dir = "" 108 } 109 return err 110 } 111 112 // run runs the MongoDB server at the 113 // address and directory already configured. 114 func (inst *MgoInstance) run() error { 115 if inst.server != nil { 116 panic("mongo server is already running") 117 } 118 119 mgoport := strconv.Itoa(inst.port) 120 mgoargs := []string{ 121 "--auth", 122 "--dbpath", inst.dir, 123 "--port", mgoport, 124 "--nssize", "1", 125 "--noprealloc", 126 "--smallfiles", 127 "--nojournal", 128 "--nounixsocket", 129 "--oplogSize", "10", 130 "--keyFile", filepath.Join(inst.dir, "keyfile"), 131 } 132 if inst.ssl { 133 mgoargs = append(mgoargs, 134 "--sslOnNormalPorts", 135 "--sslPEMKeyFile", filepath.Join(inst.dir, "server.pem"), 136 "--sslPEMKeyPassword", "ignored") 137 } 138 if inst.Params != nil { 139 mgoargs = append(mgoargs, inst.Params...) 140 } 141 server := exec.Command("mongod", mgoargs...) 142 out, err := server.StdoutPipe() 143 if err != nil { 144 return err 145 } 146 server.Stderr = server.Stdout 147 exited := make(chan struct{}) 148 go func() { 149 lines := readLines(out, 20) 150 err := server.Wait() 151 exitErr, _ := err.(*exec.ExitError) 152 if err == nil || exitErr != nil && exitErr.Exited() { 153 // mongodb has exited without being killed, so print the 154 // last few lines of its log output. 155 for _, line := range lines { 156 log.Infof("mongod: %s", line) 157 } 158 } 159 close(exited) 160 }() 161 inst.exited = exited 162 if err := server.Start(); err != nil { 163 return err 164 } 165 inst.server = server 166 167 return nil 168 } 169 170 func (inst *MgoInstance) kill() { 171 inst.server.Process.Kill() 172 <-inst.exited 173 inst.server = nil 174 inst.exited = nil 175 } 176 177 func (inst *MgoInstance) Destroy() { 178 if inst.server != nil { 179 inst.kill() 180 os.RemoveAll(inst.dir) 181 inst.addr, inst.dir = "", "" 182 } 183 } 184 185 // Restart restarts the mongo server, useful for 186 // testing what happens when a state server goes down. 187 func (inst *MgoInstance) Restart() { 188 inst.kill() 189 if err := inst.Start(inst.ssl); err != nil { 190 panic(err) 191 } 192 } 193 194 // MgoTestPackage should be called to register the tests for any package that 195 // requires a MongoDB server. 196 func MgoTestPackage(t *stdtesting.T) { 197 MgoTestPackageSsl(t, true) 198 } 199 200 func MgoTestPackageSsl(t *stdtesting.T, ssl bool) { 201 if err := MgoServer.Start(ssl); err != nil { 202 t.Fatal(err) 203 } 204 defer MgoServer.Destroy() 205 gc.TestingT(t) 206 } 207 208 func (s *MgoSuite) SetUpSuite(c *gc.C) { 209 if MgoServer.addr == "" { 210 panic("MgoSuite tests must be run with MgoTestPackage") 211 } 212 mgo.SetStats(true) 213 // Make tests that use password authentication faster. 214 utils.FastInsecureHash = true 215 } 216 217 // readLines reads lines from the given reader and returns 218 // the last n non-empty lines, ignoring empty lines. 219 func readLines(r io.Reader, n int) []string { 220 br := bufio.NewReader(r) 221 lines := make([]string, n) 222 i := 0 223 for { 224 line, err := br.ReadString('\n') 225 if line = strings.TrimRight(line, "\n"); line != "" { 226 lines[i%n] = line 227 i++ 228 } 229 if err != nil { 230 break 231 } 232 } 233 final := make([]string, 0, n+1) 234 if i > n { 235 final = append(final, fmt.Sprintf("[%d lines omitted]", i-n)) 236 } 237 for j := 0; j < n; j++ { 238 if line := lines[(j+i)%n]; line != "" { 239 final = append(final, line) 240 } 241 } 242 return final 243 } 244 245 func (s *MgoSuite) TearDownSuite(c *gc.C) { 246 utils.FastInsecureHash = false 247 } 248 249 // MustDial returns a new connection to the MongoDB server, and panics on 250 // errors. 251 func (inst *MgoInstance) MustDial() *mgo.Session { 252 s, err := inst.dial(false) 253 if err != nil { 254 panic(err) 255 } 256 return s 257 } 258 259 // Dial returns a new connection to the MongoDB server. 260 func (inst *MgoInstance) Dial() (*mgo.Session, error) { 261 return inst.dial(false) 262 } 263 264 // DialDirect returns a new direct connection to the shared MongoDB server. This 265 // must be used if you're connecting to a replicaset that hasn't been initiated 266 // yet. 267 func (inst *MgoInstance) DialDirect() (*mgo.Session, error) { 268 return inst.dial(true) 269 } 270 271 // MustDialDirect works like DialDirect, but panics on errors. 272 func (inst *MgoInstance) MustDialDirect() *mgo.Session { 273 session, err := inst.dial(true) 274 if err != nil { 275 panic(err) 276 } 277 return session 278 } 279 280 func (inst *MgoInstance) dial(direct bool) (*mgo.Session, error) { 281 pool := x509.NewCertPool() 282 xcert, err := cert.ParseCert([]byte(CACert)) 283 if err != nil { 284 return nil, err 285 } 286 pool.AddCert(xcert) 287 tlsConfig := &tls.Config{ 288 RootCAs: pool, 289 ServerName: "anything", 290 } 291 session, err := mgo.DialWithInfo(&mgo.DialInfo{ 292 Direct: direct, 293 Addrs: []string{inst.addr}, 294 Dial: func(addr net.Addr) (net.Conn, error) { 295 return tls.Dial("tcp", addr.String(), tlsConfig) 296 }, 297 Timeout: mgoDialTimeout, 298 }) 299 return session, err 300 } 301 302 func (s *MgoSuite) SetUpTest(c *gc.C) { 303 mgo.ResetStats() 304 s.Session = MgoServer.MustDial() 305 } 306 307 // Reset deletes all content from the MongoDB server and panics if it encounters 308 // errors. 309 func (inst *MgoInstance) Reset() { 310 session := inst.MustDial() 311 defer session.Close() 312 313 dbnames, ok := resetAdminPasswordAndFetchDBNames(session) 314 if ok { 315 log.Infof("Reset successfully reset admin password") 316 } else { 317 // We restart it to regain access. This should only 318 // happen when tests fail. 319 log.Noticef("testing: restarting MongoDB server after unauthorized access") 320 inst.Destroy() 321 if err := inst.Start(inst.ssl); err != nil { 322 panic(err) 323 } 324 return 325 } 326 for _, name := range dbnames { 327 switch name { 328 case "admin", "local", "config": 329 default: 330 if err := session.DB(name).DropDatabase(); err != nil { 331 panic(fmt.Errorf("Cannot drop MongoDB database %v: %v", name, err)) 332 } 333 } 334 } 335 } 336 337 // resetAdminPasswordAndFetchDBNames logs into the database with a 338 // plausible password and returns all the database's db names. We need 339 // to try several passwords because we don't know what state the mongo 340 // server is in when Reset is called. If the test has set a custom 341 // password, we're out of luck, but if they are using 342 // DefaultStatePassword, we can succeed. 343 func resetAdminPasswordAndFetchDBNames(session *mgo.Session) ([]string, bool) { 344 // First try with no password 345 dbnames, err := session.DatabaseNames() 346 if err == nil { 347 return dbnames, true 348 } 349 if !isUnauthorized(err) { 350 panic(err) 351 } 352 // Then try the two most likely passwords in turn. 353 for _, password := range []string{ 354 DefaultMongoPassword, 355 utils.UserPasswordHash(DefaultMongoPassword, utils.CompatSalt), 356 } { 357 admin := session.DB("admin") 358 if err := admin.Login("admin", password); err != nil { 359 log.Infof("failed to log in with password %q", password) 360 continue 361 } 362 dbnames, err := session.DatabaseNames() 363 if err == nil { 364 if err := admin.RemoveUser("admin"); err != nil { 365 panic(err) 366 } 367 return dbnames, true 368 } 369 if !isUnauthorized(err) { 370 panic(err) 371 } 372 log.Infof("unauthorized access when getting database names; password %q", password) 373 } 374 return nil, false 375 } 376 377 // isUnauthorized is a copy of the same function in state/open.go. 378 func isUnauthorized(err error) bool { 379 if err == nil { 380 return false 381 } 382 // Some unauthorized access errors have no error code, 383 // just a simple error string. 384 if err.Error() == "auth fails" { 385 return true 386 } 387 if err, ok := err.(*mgo.QueryError); ok { 388 return err.Code == 10057 || 389 err.Message == "need to login" || 390 err.Message == "unauthorized" 391 } 392 return false 393 } 394 395 func (s *MgoSuite) TearDownTest(c *gc.C) { 396 MgoServer.Reset() 397 s.Session.Close() 398 for i := 0; ; i++ { 399 stats := mgo.GetStats() 400 if stats.SocketsInUse == 0 && stats.SocketsAlive == 0 { 401 break 402 } 403 if i == 20 { 404 c.Fatal("Test left sockets in a dirty state") 405 } 406 c.Logf("Waiting for sockets to die: %d in use, %d alive", stats.SocketsInUse, stats.SocketsAlive) 407 time.Sleep(500 * time.Millisecond) 408 } 409 } 410 411 // FindTCPPort finds an unused TCP port and returns it. 412 // Use of this function has an inherent race condition - another 413 // process may claim the port before we try to use it. 414 // We hope that the probability is small enough during 415 // testing to be negligible. 416 func FindTCPPort() int { 417 l, err := net.Listen("tcp", "127.0.0.1:0") 418 if err != nil { 419 panic(err) 420 } 421 l.Close() 422 return l.Addr().(*net.TCPAddr).Port 423 }