github.com/apcera/util@v0.0.0-20180322191801-7a50bc84ee48/testtool/testtool.go (about)

     1  // Copyright 2013 Apcera Inc. All rights reserved.
     2  
     3  package testtool
     4  
     5  import (
     6  	"crypto/md5"
     7  	"encoding/base64"
     8  	"flag"
     9  	"fmt"
    10  	"io/ioutil"
    11  	"math/rand"
    12  	"os"
    13  	"path"
    14  	"reflect"
    15  	"runtime"
    16  	"strings"
    17  	"testing"
    18  	"time"
    19  
    20  	"github.com/apcera/logray"
    21  	"github.com/apcera/logray/unittest"
    22  )
    23  
    24  // Logger is a common interface that can be used to allow testing.B and
    25  // testing.T objects to be passed to the same function.
    26  type Logger interface {
    27  	Error(args ...interface{})
    28  	Errorf(format string, args ...interface{})
    29  	Failed() bool
    30  	Fatal(args ...interface{})
    31  	Fatalf(format string, args ...interface{})
    32  	Skip(args ...interface{})
    33  	Skipf(format string, args ...interface{})
    34  	Log(args ...interface{})
    35  	Logf(format string, args ...interface{})
    36  }
    37  
// Backtracer is an interface that provides additional information to be
// displayed by TestExpectSuccess(). For an example see BackError in the
// apcera/cfg package.
type Backtracer interface {
	// Backtrace returns the lines to show for the error. TestExpectSuccess
	// prints each returned line on its own row, prefixed with " * ".
	Backtrace() []string
}
    44  
    45  // -----------------------------------------------------------------------
    46  // Initialization, cleanup, and shutdown functions.
    47  // -----------------------------------------------------------------------
    48  
// streamTestOutput is the target of the -live-output flag registered in
// init(). When it is true StartTest sends log output to stdout as it happens;
// when it is false StartTest installs a log buffer instead, so output is held
// back rather than displayed live.
var streamTestOutput bool
    52  
    53  // If a -log or log is provided with an path to a directory then that path is
    54  // available in this variable. This is a helper for tests that wish to log. An
    55  // empty string indicates the path was not set. The value is set only to allow
    56  // callers to make use of in their tests. There are no other side effects.
    57  var TestLogFile string = ""
    58  
    59  func init() {
    60  	if f := flag.Lookup("log"); f == nil {
    61  		flag.StringVar(
    62  			&TestLogFile,
    63  			"log",
    64  			"",
    65  			"Specifies the log file for the test")
    66  	}
    67  	if f := flag.Lookup("live-output"); f == nil {
    68  		flag.BoolVar(
    69  			&streamTestOutput,
    70  			"live-output",
    71  			false,
    72  			"Enable output to be streamed live rather than buffering.")
    73  	}
    74  }
    75  
// TestTool type allows for parallel tests. StartTest creates one per test and
// FinishTest tears it down again.
type TestTool struct {
	// The test or benchmark handle that was passed to StartTest.
	testing.TB

	// Stores output from the logging system so it can be written only if
	// the test actually fails. StartTest leaves this nil when live output
	// streaming is enabled.
	LogBuffer *unittest.LogBuffer

	// This is a list of functions that will be run on test completion. Having
	// this allows us to clean up temporary directories or files after the
	// test is done which is a huge win. FinishTest runs them in reverse
	// order of registration.
	Finalizers []func()

	// Parameters contains test-specific caches of data.
	Parameters map[string]interface{}

	// RandomTestString is filled by StartTest with RandomTestString(10).
	// PackageHash is filled by StartTest with the Package value followed by
	// hashPackage(PackageDir).
	RandomTestString string
	PackageHash      string

	// Populated by StartTest from GetTestData; StartTest panics if that
	// returns nil.
	*TestData
}
    97  
// AddTestFinalizer adds a function to be called once the test finishes.
// Finalizers are appended to tt.Finalizers; FinishTest runs them starting
// with the most recently added one.
func (tt *TestTool) AddTestFinalizer(f func()) {
	tt.Finalizers = append(tt.Finalizers, f)
}
   102  
   103  // StartTest should be called at the start of a test to setup all the various
   104  // state bits that are needed.
   105  func StartTest(tb testing.TB) *TestTool {
   106  	tt := TestTool{
   107  		Parameters:       make(map[string]interface{}),
   108  		TB:               tb,
   109  		RandomTestString: RandomTestString(10),
   110  	}
   111  
   112  	tt.TestData = GetTestData(tb)
   113  
   114  	if tt.TestData == nil {
   115  		panic("Failed to read information about the test.")
   116  	}
   117  
   118  	tt.PackageHash = tt.Package + hashPackage(tt.PackageDir)
   119  
   120  	if !streamTestOutput {
   121  		tt.LogBuffer = unittest.SetupBuffer()
   122  	} else {
   123  		logray.AddDefaultOutput("stdout://", logray.ALL)
   124  	}
   125  
   126  	return &tt
   127  }
   128  
   129  // FinishTest is called as a defer to a test in order to clean up after a test
   130  // run. All tests in this module should call this function as a defer right
   131  // after calling StartTest()
   132  func (tt *TestTool) FinishTest() {
   133  	for i := len(tt.Finalizers) - 1; i >= 0; i-- {
   134  		tt.Finalizers[i]()
   135  	}
   136  	tt.Finalizers = nil
   137  	if tt.LogBuffer != nil {
   138  		tt.LogBuffer.FinishTest(tt.TB)
   139  	}
   140  }
   141  
// TestRequiresRoot is called to require that your test is run as root. NOTICE:
// this does not cause the test to FAIL; when the process is not running as
// uid 0 the test is skipped via l.Skipf instead. This seems like the most sane
// thing to do based on the shortcomings of Go's test utilities.
//
// As an added feature this function will append the name of each skipped test
// (one per line) to the file named in the environment variable:
//   $SKIPPED_ROOT_TESTS_FILE
// The file is created if it does not exist. Nothing is recorded when the
// variable is unset or empty.
func TestRequiresRoot(l Logger) {
	// getTestName walks the call stack looking for the calling test function.
	// It must stay a closure invoked directly from TestRequiresRoot: the skip
	// count passed to runtime.Callers below depends on that nesting.
	getTestName := func() string {
		// Maximum function depth. This shouldn't be called when the stack is
		// 1024 calls deep (its typically called at the top of the Test).
		pc := make([]uintptr, 1024)
		callers := runtime.Callers(2, pc)
		testname := ""
		// The loop does not stop at the first hit: every later match
		// overwrites testname, so the last matching frame in pc wins.
		// NOTE(review): runtime.Callers lists the innermost frame first, so
		// that should be the outermost Test* function — confirm.
		for i := 0; i < callers; i++ {
			if f := runtime.FuncForPC(pc[i]); f != nil {
				// Function names have the following formats:
				//   runtime.goexit
				//   testing.tRunner
				//   github.com/util/testtool.TestRequiresRoot
				// To find the real function name we split on . and take the
				// last element.
				names := strings.Split(f.Name(), ".")
				if strings.HasPrefix(names[len(names)-1], "Test") {
					testname = names[len(names)-1]
				}
			}
		}
		if testname == "" {
			Fatalf(l, "Can't figure out the test name.")
		}
		return testname
	}

	if os.Getuid() != 0 {
		// We support the ability to set an environment variables where the
		// names of all skipped tests will be logged. This is used to ensure
		// that they can be run with sudo later.
		fn := os.Getenv("SKIPPED_ROOT_TESTS_FILE")
		if fn != "" {
			// Open the file for appending (creating it with mode 0644 when
			// missing) and add the test name found by getTestName. Any error
			// opening or writing the file is reported through
			// TestExpectSuccess.
			flags := os.O_WRONLY | os.O_APPEND | os.O_CREATE
			f, err := os.OpenFile(fn, flags, os.FileMode(0644))
			TestExpectSuccess(l, err)
			defer f.Close()
			_, err = f.WriteString(getTestName() + "\n")
			TestExpectSuccess(l, err)
		}

		l.Skipf("This test must be run as root. Skipping.")
	}
}
   196  
   197  // -----------------------------------------------------------------------
   198  // Temporary file helpers.
   199  // -----------------------------------------------------------------------
   200  
   201  // WriteTempFile writes contents to a temporary file, sets up a Finalizer to
   202  // remove the file once the test is complete, and then returns the newly created
   203  // filename to the caller.
   204  func (tt *TestTool) WriteTempFile(contents string) string {
   205  	return tt.WriteTempFileMode(contents, os.FileMode(0644))
   206  }
   207  
   208  // WriteTempFileMode is like WriteTempFile but sets the mode.
   209  func (tt *TestTool) WriteTempFileMode(contents string, mode os.FileMode) string {
   210  	f, err := ioutil.TempFile("", "golangunittest")
   211  	if f == nil {
   212  		Fatalf(tt.TB, "ioutil.TempFile() return nil.")
   213  	} else if err != nil {
   214  		Fatalf(tt.TB, "ioutil.TempFile() return an err: %s", err)
   215  	} else if err := os.Chmod(f.Name(), mode); err != nil {
   216  		Fatalf(tt.TB, "os.Chmod() returned an error: %s", err)
   217  	}
   218  	defer f.Close()
   219  	tt.Finalizers = append(tt.Finalizers, func() {
   220  		os.Remove(f.Name())
   221  	})
   222  	contentsBytes := []byte(contents)
   223  	n, err := f.Write(contentsBytes)
   224  	if err != nil {
   225  		Fatalf(tt.TB, "Error writing to %s: %s", f.Name(), err)
   226  	} else if n != len(contentsBytes) {
   227  		Fatalf(tt.TB, "Short write to %s", f.Name())
   228  	}
   229  	return f.Name()
   230  }
   231  
   232  // TempDir makes a temporary directory.
   233  func (tt *TestTool) TempDir() string {
   234  	return tt.TempDirMode(os.FileMode(0755))
   235  }
   236  
   237  // TempDirMode makes a temporary directory with the given mode.
   238  func (tt *TestTool) TempDirMode(mode os.FileMode) string {
   239  	f, err := ioutil.TempDir(RootTempDir(tt), "golangunittest")
   240  	if f == "" {
   241  		Fatalf(tt.TB, "ioutil.TempFile() return an empty string.")
   242  	} else if err != nil {
   243  		Fatalf(tt.TB, "ioutil.TempFile() return an err: %s", err)
   244  	} else if err := os.Chmod(f, mode); err != nil {
   245  		Fatalf(tt.TB, "os.Chmod failure.")
   246  	}
   247  
   248  	tt.Finalizers = append(tt.Finalizers, func() {
   249  		os.RemoveAll(f)
   250  	})
   251  	return f
   252  }
   253  
   254  // TempFile allocate a temporary file and ensures that it gets cleaned up when
   255  // the test is completed.
   256  func (tt *TestTool) TempFile() string {
   257  	return tt.TempFileMode(os.FileMode(0644))
   258  }
   259  
   260  // TempFileMode writes a temp file with the given mode.
   261  func (tt *TestTool) TempFileMode(mode os.FileMode) string {
   262  	f, err := ioutil.TempFile(RootTempDir(tt), "unittest")
   263  	if err != nil {
   264  		Fatalf(tt.TB, "Error making temporary file: %s", err)
   265  	} else if err := os.Chmod(f.Name(), mode); err != nil {
   266  		Fatalf(tt.TB, "os.Chmod failure.")
   267  	}
   268  	defer f.Close()
   269  	name := f.Name()
   270  	tt.Finalizers = append(tt.Finalizers, func() {
   271  		os.RemoveAll(name)
   272  	})
   273  	return name
   274  }
   275  
   276  // -----------------------------------------------------------------------
   277  // Fatalf wrapper.
   278  // -----------------------------------------------------------------------
   279  
   280  // Fatalf wraps Fatalf in order to provide a functional stack trace on failures
   281  // rather than just a line number of the failing check. This helps if you have a
   282  // test that fails in a loop since it will show the path to get there as well as
   283  // the error directly.
   284  func Fatalf(l Logger, f string, args ...interface{}) {
   285  	lines := make([]string, 0, 100)
   286  	msg := fmt.Sprintf(f, args...)
   287  	lines = append(lines, msg)
   288  
   289  	// Get the directory of testtool in order to ensure that we don't show
   290  	// it in the stack traces (it can be spammy).
   291  	_, myfile, _, _ := runtime.Caller(0)
   292  	mydir := path.Dir(myfile)
   293  
   294  	// Generate the Stack of callers:
   295  	for i := 0; true; i++ {
   296  		_, file, line, ok := runtime.Caller(i)
   297  		if ok == false {
   298  			break
   299  		}
   300  		// Don't print the stack line if its within testtool since its
   301  		// annoying to see the testtool internals.
   302  		if path.Dir(file) == mydir {
   303  			continue
   304  		}
   305  		msg := fmt.Sprintf("%d - %s:%d", i, file, line)
   306  		lines = append(lines, msg)
   307  	}
   308  	l.Fatalf("%s", strings.Join(lines, "\n"))
   309  }
   310  
   311  // Fatal fails the test with a simple format for the message.
   312  func Fatal(t Logger, args ...interface{}) {
   313  	Fatalf(t, "%s", fmt.Sprint(args...))
   314  }
   315  
   316  // -----------------------------------------------------------------------
   317  // Simple Timeout functions
   318  // -----------------------------------------------------------------------
   319  
   320  // Timeout runs the given function until 'timeout' has passed, sleeping 'sleep'
   321  // duration in between runs. If the function returns true this exits, otherwise
   322  // after timeout this will fail the test.
   323  func Timeout(l Logger, timeout, sleep time.Duration, f func() bool) {
   324  	end := time.Now().Add(timeout)
   325  	for time.Now().Before(end) {
   326  		if f() == true {
   327  			return
   328  		}
   329  		time.Sleep(sleep)
   330  	}
   331  	Fatalf(l, "testtool: Timeout after %v", timeout)
   332  }
   333  
   334  // -----------------------------------------------------------------------
   335  // Error object handling functions.
   336  // -----------------------------------------------------------------------
   337  
   338  // TestExpectError calls Fatal if err is nil.
   339  func TestExpectError(l Logger, err error, msg ...string) {
   340  	reason := ""
   341  	if len(msg) > 0 {
   342  		reason = ": " + strings.Join(msg, "")
   343  	}
   344  	if err == nil {
   345  		Fatalf(l, "Expected an error, got nil%s", reason)
   346  	}
   347  }
   348  
   349  // isRealError detects if a nil of a type stored as a concrete type, rather than
   350  // an error interface, was passed in.
   351  func isRealError(err error) bool {
   352  	if err == nil {
   353  		return false
   354  	}
   355  	v := reflect.ValueOf(err)
   356  	if !v.CanInterface() {
   357  		return true
   358  	}
   359  	if v.IsNil() {
   360  		return false
   361  	}
   362  	return true
   363  }
   364  
   365  // TestExpectSuccess fails the test if err is not nil and fails the test and
   366  // output the reason for the failure as the err argument the same as Fatalf. If
   367  // err implements the BackTracer interface a backtrace will also be displayed.
   368  func TestExpectSuccess(l Logger, err error, msg ...string) {
   369  	reason := ""
   370  	if len(msg) > 0 {
   371  		reason = ": " + strings.Join(msg, "")
   372  	}
   373  	if err != nil && isRealError(err) {
   374  		lines := make([]string, 0, 50)
   375  		lines = append(lines, fmt.Sprintf("Unexpected error: %s", err))
   376  		if be, ok := err.(Backtracer); ok {
   377  			for _, line := range be.Backtrace() {
   378  				lines = append(lines, fmt.Sprintf(" * %s", line))
   379  			}
   380  		}
   381  		Fatalf(l, "%s%s", strings.Join(lines, "\n"), reason)
   382  	}
   383  }
   384  
   385  // TestExpectNonZeroLength fails the test if the given value is not zero.
   386  func TestExpectZeroLength(l Logger, size int) {
   387  	if size != 0 {
   388  		Fatalf(l, "Zero length expected")
   389  	}
   390  }
   391  
   392  // TestExpectNonZeroLength fails the test if the given value is zero.
   393  func TestExpectNonZeroLength(l Logger, size int) {
   394  	if size == 0 {
   395  		Fatalf(l, "Zero length found")
   396  	}
   397  }
   398  
   399  // TestExpectPanic verifies that a panic is called with the expected msg.
   400  func TestExpectPanic(l Logger, f func(), msg string) {
   401  	defer func(msg string) {
   402  		if m := recover(); m != nil {
   403  			TestEqual(l, msg, m)
   404  		}
   405  	}(msg)
   406  	f()
   407  	Fatalf(l, "Expected a panic with message '%s'\n", msg)
   408  }
   409  
   410  var letters = []rune("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ")
   411  
   412  // RandomTestString generates a random test string from only upper and lower
   413  // case letters.
   414  func RandomTestString(n int) string {
   415  	b := make([]rune, n)
   416  	for i := range b {
   417  		b[i] = letters[rand.Intn(len(letters))]
   418  	}
   419  
   420  	return string(b)
   421  }
   422  
   423  var encoder = base64.NewEncoding("ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789__")
   424  
   425  func hashPackage(pkg string) string {
   426  	hash := md5.New()
   427  	hash.Write([]byte(pkg))
   428  	out := encoder.EncodeToString(hash.Sum(nil))
   429  	//Ensure alphanumber prefix and remove base64 padding
   430  	return "p" + out[:len(out)-2]
   431  }