launchpad.net/~rogpeppe/juju-core/500-errgo-fix@v0.0.0-20140213181702-000000002356/testing/mgo.go (about)

     1  // Copyright 2012, 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package testing
     5  
     6  import (
     7  	"bufio"
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"fmt"
    11  	"io"
    12  	"io/ioutil"
    13  	"net"
    14  	"os"
    15  	"os/exec"
    16  	"path/filepath"
    17  	"strconv"
    18  	"strings"
    19  	stdtesting "testing"
    20  	"time"
    21  
    22  	"labix.org/v2/mgo"
    23  	gc "launchpad.net/gocheck"
    24  
    25  	"launchpad.net/juju-core/cert"
    26  	"launchpad.net/juju-core/log"
    27  	"launchpad.net/juju-core/utils"
    28  )
    29  
    30  var (
    31  	// MgoServer is a shared mongo server used by tests.
    32  	MgoServer = &MgoInstance{ssl: true}
    33  )
    34  
    35  type MgoInstance struct {
    36  	// addr holds the address of the MongoDB server
    37  	addr string
    38  
    39  	// MgoPort holds the port of the MongoDB server.
    40  	port int
    41  
    42  	// server holds the running MongoDB command.
    43  	server *exec.Cmd
    44  
    45  	// exited receives a value when the mongodb server exits.
    46  	exited <-chan struct{}
    47  
    48  	// dir holds the directory that MongoDB is running in.
    49  	dir string
    50  
    51  	// ssl determines whether the MongoDB server will use TLS
    52  	ssl bool
    53  
    54  	// Params is a list of additional parameters that will be passed to
    55  	// the mongod application
    56  	Params []string
    57  }
    58  
    59  // Addr returns the address of the MongoDB server.
    60  func (m *MgoInstance) Addr() string {
    61  	return m.addr
    62  }
    63  
    64  // Port returns the port of the MongoDB server.
    65  func (m *MgoInstance) Port() int {
    66  	return m.port
    67  }
    68  
    69  // We specify a timeout to mgo.Dial, to prevent
    70  // mongod failures hanging the tests.
    71  const mgoDialTimeout = 15 * time.Second
    72  
    73  // MgoSuite is a suite that deletes all content from the shared MongoDB
    74  // server at the end of every test and supplies a connection to the shared
    75  // MongoDB server.
    76  type MgoSuite struct {
    77  	Session *mgo.Session
    78  }
    79  
    80  // Start starts a MongoDB server in a temporary directory.
    81  func (inst *MgoInstance) Start(ssl bool) error {
    82  	dbdir, err := ioutil.TempDir("", "test-mgo")
    83  	if err != nil {
    84  		return err
    85  	}
    86  
    87  	// give them all the same keyfile so they can talk appropriately
    88  	keyFilePath := filepath.Join(dbdir, "keyfile")
    89  	err = ioutil.WriteFile(keyFilePath, []byte("not very secret"), 0600)
    90  	if err != nil {
    91  		return fmt.Errorf("cannot write key file: %v", err)
    92  	}
    93  
    94  	pemPath := filepath.Join(dbdir, "server.pem")
    95  	err = ioutil.WriteFile(pemPath, []byte(ServerCert+ServerKey), 0600)
    96  	if err != nil {
    97  		return fmt.Errorf("cannot write cert/key PEM: %v", err)
    98  	}
    99  	inst.port = FindTCPPort()
   100  	inst.addr = fmt.Sprintf("localhost:%d", inst.port)
   101  	inst.dir = dbdir
   102  	inst.ssl = ssl
   103  	if err := inst.run(); err != nil {
   104  		inst.addr = ""
   105  		inst.port = 0
   106  		os.RemoveAll(inst.dir)
   107  		inst.dir = ""
   108  	}
   109  	return err
   110  }
   111  
   112  // run runs the MongoDB server at the
   113  // address and directory already configured.
   114  func (inst *MgoInstance) run() error {
   115  	if inst.server != nil {
   116  		panic("mongo server is already running")
   117  	}
   118  
   119  	mgoport := strconv.Itoa(inst.port)
   120  	mgoargs := []string{
   121  		"--auth",
   122  		"--dbpath", inst.dir,
   123  		"--port", mgoport,
   124  		"--nssize", "1",
   125  		"--noprealloc",
   126  		"--smallfiles",
   127  		"--nojournal",
   128  		"--nounixsocket",
   129  		"--oplogSize", "10",
   130  		"--keyFile", filepath.Join(inst.dir, "keyfile"),
   131  	}
   132  	if inst.ssl {
   133  		mgoargs = append(mgoargs,
   134  			"--sslOnNormalPorts",
   135  			"--sslPEMKeyFile", filepath.Join(inst.dir, "server.pem"),
   136  			"--sslPEMKeyPassword", "ignored")
   137  	}
   138  	if inst.Params != nil {
   139  		mgoargs = append(mgoargs, inst.Params...)
   140  	}
   141  	server := exec.Command("mongod", mgoargs...)
   142  	out, err := server.StdoutPipe()
   143  	if err != nil {
   144  		return err
   145  	}
   146  	server.Stderr = server.Stdout
   147  	exited := make(chan struct{})
   148  	go func() {
   149  		lines := readLines(out, 20)
   150  		err := server.Wait()
   151  		exitErr, _ := err.(*exec.ExitError)
   152  		if err == nil || exitErr != nil && exitErr.Exited() {
   153  			// mongodb has exited without being killed, so print the
   154  			// last few lines of its log output.
   155  			for _, line := range lines {
   156  				log.Infof("mongod: %s", line)
   157  			}
   158  		}
   159  		close(exited)
   160  	}()
   161  	inst.exited = exited
   162  	if err := server.Start(); err != nil {
   163  		return err
   164  	}
   165  	inst.server = server
   166  
   167  	return nil
   168  }
   169  
   170  func (inst *MgoInstance) kill() {
   171  	inst.server.Process.Kill()
   172  	<-inst.exited
   173  	inst.server = nil
   174  	inst.exited = nil
   175  }
   176  
   177  func (inst *MgoInstance) Destroy() {
   178  	if inst.server != nil {
   179  		inst.kill()
   180  		os.RemoveAll(inst.dir)
   181  		inst.addr, inst.dir = "", ""
   182  	}
   183  }
   184  
   185  // Restart restarts the mongo server, useful for
   186  // testing what happens when a state server goes down.
   187  func (inst *MgoInstance) Restart() {
   188  	inst.kill()
   189  	if err := inst.Start(inst.ssl); err != nil {
   190  		panic(err)
   191  	}
   192  }
   193  
   194  // MgoTestPackage should be called to register the tests for any package that
   195  // requires a MongoDB server.
   196  func MgoTestPackage(t *stdtesting.T) {
   197  	MgoTestPackageSsl(t, true)
   198  }
   199  
   200  func MgoTestPackageSsl(t *stdtesting.T, ssl bool) {
   201  	if err := MgoServer.Start(ssl); err != nil {
   202  		t.Fatal(err)
   203  	}
   204  	defer MgoServer.Destroy()
   205  	gc.TestingT(t)
   206  }
   207  
   208  func (s *MgoSuite) SetUpSuite(c *gc.C) {
   209  	if MgoServer.addr == "" {
   210  		panic("MgoSuite tests must be run with MgoTestPackage")
   211  	}
   212  	mgo.SetStats(true)
   213  	// Make tests that use password authentication faster.
   214  	utils.FastInsecureHash = true
   215  }
   216  
   217  // readLines reads lines from the given reader and returns
   218  // the last n non-empty lines, ignoring empty lines.
   219  func readLines(r io.Reader, n int) []string {
   220  	br := bufio.NewReader(r)
   221  	lines := make([]string, n)
   222  	i := 0
   223  	for {
   224  		line, err := br.ReadString('\n')
   225  		if line = strings.TrimRight(line, "\n"); line != "" {
   226  			lines[i%n] = line
   227  			i++
   228  		}
   229  		if err != nil {
   230  			break
   231  		}
   232  	}
   233  	final := make([]string, 0, n+1)
   234  	if i > n {
   235  		final = append(final, fmt.Sprintf("[%d lines omitted]", i-n))
   236  	}
   237  	for j := 0; j < n; j++ {
   238  		if line := lines[(j+i)%n]; line != "" {
   239  			final = append(final, line)
   240  		}
   241  	}
   242  	return final
   243  }
   244  
   245  func (s *MgoSuite) TearDownSuite(c *gc.C) {
   246  	utils.FastInsecureHash = false
   247  }
   248  
   249  // MustDial returns a new connection to the MongoDB server, and panics on
   250  // errors.
   251  func (inst *MgoInstance) MustDial() *mgo.Session {
   252  	s, err := inst.dial(false)
   253  	if err != nil {
   254  		panic(err)
   255  	}
   256  	return s
   257  }
   258  
   259  // Dial returns a new connection to the MongoDB server.
   260  func (inst *MgoInstance) Dial() (*mgo.Session, error) {
   261  	return inst.dial(false)
   262  }
   263  
   264  // DialDirect returns a new direct connection to the shared MongoDB server. This
   265  // must be used if you're connecting to a replicaset that hasn't been initiated
   266  // yet.
   267  func (inst *MgoInstance) DialDirect() (*mgo.Session, error) {
   268  	return inst.dial(true)
   269  }
   270  
   271  // MustDialDirect works like DialDirect, but panics on errors.
   272  func (inst *MgoInstance) MustDialDirect() *mgo.Session {
   273  	session, err := inst.dial(true)
   274  	if err != nil {
   275  		panic(err)
   276  	}
   277  	return session
   278  }
   279  
   280  func (inst *MgoInstance) dial(direct bool) (*mgo.Session, error) {
   281  	pool := x509.NewCertPool()
   282  	xcert, err := cert.ParseCert([]byte(CACert))
   283  	if err != nil {
   284  		return nil, err
   285  	}
   286  	pool.AddCert(xcert)
   287  	tlsConfig := &tls.Config{
   288  		RootCAs:    pool,
   289  		ServerName: "anything",
   290  	}
   291  	session, err := mgo.DialWithInfo(&mgo.DialInfo{
   292  		Direct: direct,
   293  		Addrs:  []string{inst.addr},
   294  		Dial: func(addr net.Addr) (net.Conn, error) {
   295  			return tls.Dial("tcp", addr.String(), tlsConfig)
   296  		},
   297  		Timeout: mgoDialTimeout,
   298  	})
   299  	return session, err
   300  }
   301  
   302  func (s *MgoSuite) SetUpTest(c *gc.C) {
   303  	mgo.ResetStats()
   304  	s.Session = MgoServer.MustDial()
   305  }
   306  
   307  // Reset deletes all content from the MongoDB server and panics if it encounters
   308  // errors.
   309  func (inst *MgoInstance) Reset() {
   310  	session := inst.MustDial()
   311  	defer session.Close()
   312  
   313  	dbnames, ok := resetAdminPasswordAndFetchDBNames(session)
   314  	if ok {
   315  		log.Infof("Reset successfully reset admin password")
   316  	} else {
   317  		// We restart it to regain access.  This should only
   318  		// happen when tests fail.
   319  		log.Noticef("testing: restarting MongoDB server after unauthorized access")
   320  		inst.Destroy()
   321  		if err := inst.Start(inst.ssl); err != nil {
   322  			panic(err)
   323  		}
   324  		return
   325  	}
   326  	for _, name := range dbnames {
   327  		switch name {
   328  		case "admin", "local", "config":
   329  		default:
   330  			if err := session.DB(name).DropDatabase(); err != nil {
   331  				panic(fmt.Errorf("Cannot drop MongoDB database %v: %v", name, err))
   332  			}
   333  		}
   334  	}
   335  }
   336  
   337  // resetAdminPasswordAndFetchDBNames logs into the database with a
   338  // plausible password and returns all the database's db names. We need
   339  // to try several passwords because we don't know what state the mongo
   340  // server is in when Reset is called. If the test has set a custom
   341  // password, we're out of luck, but if they are using
   342  // DefaultStatePassword, we can succeed.
   343  func resetAdminPasswordAndFetchDBNames(session *mgo.Session) ([]string, bool) {
   344  	// First try with no password
   345  	dbnames, err := session.DatabaseNames()
   346  	if err == nil {
   347  		return dbnames, true
   348  	}
   349  	if !isUnauthorized(err) {
   350  		panic(err)
   351  	}
   352  	// Then try the two most likely passwords in turn.
   353  	for _, password := range []string{
   354  		DefaultMongoPassword,
   355  		utils.UserPasswordHash(DefaultMongoPassword, utils.CompatSalt),
   356  	} {
   357  		admin := session.DB("admin")
   358  		if err := admin.Login("admin", password); err != nil {
   359  			log.Infof("failed to log in with password %q", password)
   360  			continue
   361  		}
   362  		dbnames, err := session.DatabaseNames()
   363  		if err == nil {
   364  			if err := admin.RemoveUser("admin"); err != nil {
   365  				panic(err)
   366  			}
   367  			return dbnames, true
   368  		}
   369  		if !isUnauthorized(err) {
   370  			panic(err)
   371  		}
   372  		log.Infof("unauthorized access when getting database names; password %q", password)
   373  	}
   374  	return nil, false
   375  }
   376  
   377  // isUnauthorized is a copy of the same function in state/open.go.
   378  func isUnauthorized(err error) bool {
   379  	if err == nil {
   380  		return false
   381  	}
   382  	// Some unauthorized access errors have no error code,
   383  	// just a simple error string.
   384  	if err.Error() == "auth fails" {
   385  		return true
   386  	}
   387  	if err, ok := err.(*mgo.QueryError); ok {
   388  		return err.Code == 10057 ||
   389  			err.Message == "need to login" ||
   390  			err.Message == "unauthorized"
   391  	}
   392  	return false
   393  }
   394  
   395  func (s *MgoSuite) TearDownTest(c *gc.C) {
   396  	MgoServer.Reset()
   397  	s.Session.Close()
   398  	for i := 0; ; i++ {
   399  		stats := mgo.GetStats()
   400  		if stats.SocketsInUse == 0 && stats.SocketsAlive == 0 {
   401  			break
   402  		}
   403  		if i == 20 {
   404  			c.Fatal("Test left sockets in a dirty state")
   405  		}
   406  		c.Logf("Waiting for sockets to die: %d in use, %d alive", stats.SocketsInUse, stats.SocketsAlive)
   407  		time.Sleep(500 * time.Millisecond)
   408  	}
   409  }
   410  
   411  // FindTCPPort finds an unused TCP port and returns it.
   412  // Use of this function has an inherent race condition - another
   413  // process may claim the port before we try to use it.
   414  // We hope that the probability is small enough during
   415  // testing to be negligible.
   416  func FindTCPPort() int {
   417  	l, err := net.Listen("tcp", "127.0.0.1:0")
   418  	if err != nil {
   419  		panic(err)
   420  	}
   421  	l.Close()
   422  	return l.Addr().(*net.TCPAddr).Port
   423  }