github.com/blend/go-sdk@v1.20220411.3/assert/assert.go (about)

     1  /*
     2  
     3  Copyright (c) 2022 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package assert
     9  
    10  import (
    11  	"context"
    12  	"fmt"
    13  	"io"
    14  	"math"
    15  	"math/rand"
    16  	"os"
    17  	"reflect"
    18  	"regexp"
    19  	"runtime"
    20  	"strings"
    21  	"sync/atomic"
    22  	"testing"
    23  	"time"
    24  	"unicode"
    25  	"unicode/utf8"
    26  )
    27  
    28  // Empty returns an empty assertions handler; useful when you want to apply assertions w/o hooking into the testing framework.
    29  func Empty(opts ...Option) *Assertions {
    30  	a := Assertions{
    31  		OutputFormat: OutputFormatFromEnv(),
    32  		Context:      WithContextID(context.Background(), randomString(8)),
    33  	}
    34  	for _, opt := range opts {
    35  		opt(&a)
    36  	}
    37  	return &a
    38  }
    39  
    40  // New returns a new instance of `Assertions`.
    41  func New(t *testing.T, opts ...Option) *Assertions {
    42  	a := Assertions{
    43  		T:            t,
    44  		OutputFormat: OutputFormatFromEnv(),
    45  		Context:      WithContextID(context.Background(), randomString(8)),
    46  	}
    47  	if t != nil {
    48  		a.Context = WithTestName(a.Context, t.Name())
    49  	}
    50  	for _, opt := range opts {
    51  		opt(&a)
    52  	}
    53  	return &a
    54  }
    55  
// Assertions is the main entry point for using the assertions library.
type Assertions struct {
	// Output, when non-nil, also receives the rendered text of every failure.
	Output       io.Writer
	// OutputFormat selects how failures are rendered (text or JSON).
	OutputFormat OutputFormat
	// T is the test failures are reported to; nil means failures panic instead of calling t.FailNow.
	T            *testing.T
	// Context is returned by Background; the constructors seed it with a context ID (and test name).
	Context      context.Context
	// Optional makes failing assertions non-fatal (see NonFatal).
	Optional     bool
	// Count is the number of assertions evaluated, incremented atomically.
	Count        int32
}
    65  
// Background returns the assertions context.
// The handler's Context field is returned as-is, so it may be nil if the
// Assertions was built without one.
func (a *Assertions) Background() context.Context {
	return a.Context
}
    70  
// assertion represents the actions to take for *each* assertion.
// it is used internally for stats tracking: it atomically increments Count so
// the counter stays consistent even if assertions run from several goroutines.
func (a *Assertions) assertion() {
	atomic.AddInt32(&a.Count, 1)
}
    76  
    77  // NonFatal transitions the assertion into a `NonFatal` assertion; that is, one that will not cause the test to abort if it fails.
    78  // NonFatal assertions are useful when you want to check many properties during a test, but only on an informational basis.
    79  // They will typically return a bool to indicate if the assertion succeeded, or if you should consider the overall
    80  // test to still be a success.
    81  func (a *Assertions) NonFatal() *Assertions { //golint you can bite me.
    82  	return &Assertions{
    83  		T:            a.T,
    84  		Output:       a.Output,
    85  		OutputFormat: a.OutputFormat,
    86  		Optional:     true,
    87  	}
    88  }
    89  
    90  // NotImplemented will just error.
    91  func (a *Assertions) NotImplemented(userMessageComponents ...interface{}) {
    92  	fail(a.Output, a.T, a.OutputFormat, NewFailure("the current test is not implemented", userMessageComponents...))
    93  }
    94  
    95  func (a *Assertions) fail(message string, userMessageComponents ...interface{}) bool {
    96  	if a.Optional {
    97  		fail(a.Output, a.T, a.OutputFormat, NewFailure(message, userMessageComponents...))
    98  		return false
    99  	}
   100  	failNow(a.Output, a.T, a.OutputFormat, NewFailure(message, userMessageComponents...))
   101  	return false
   102  }
   103  
   104  // NotNil asserts that a reference is not nil.
   105  func (a *Assertions) NotNil(object interface{}, userMessageComponents ...interface{}) bool {
   106  	a.assertion()
   107  	if didFail, message := shouldNotBeNil(object); didFail {
   108  		return a.fail(message, userMessageComponents...)
   109  	}
   110  	return true
   111  }
   112  
   113  // Nil asserts that a reference is nil.
   114  func (a *Assertions) Nil(object interface{}, userMessageComponents ...interface{}) bool {
   115  	a.assertion()
   116  	if didFail, message := shouldBeNil(object); didFail {
   117  		return a.fail(message, userMessageComponents...)
   118  	}
   119  	return true
   120  }
   121  
   122  // Len asserts that a collection has a given length.
   123  func (a *Assertions) Len(collection interface{}, length int, userMessageComponents ...interface{}) bool {
   124  	a.assertion()
   125  	if didFail, message := shouldHaveLength(collection, length); didFail {
   126  		return a.fail(message, userMessageComponents...)
   127  	}
   128  	return true
   129  }
   130  
   131  // Empty asserts that a collection is empty.
   132  func (a *Assertions) Empty(collection interface{}, userMessageComponents ...interface{}) bool {
   133  	a.assertion()
   134  	if didFail, message := shouldBeEmpty(collection); didFail {
   135  		return a.fail(message, userMessageComponents...)
   136  	}
   137  	return true
   138  }
   139  
   140  // NotEmpty asserts that a collection is not empty.
   141  func (a *Assertions) NotEmpty(collection interface{}, userMessageComponents ...interface{}) bool {
   142  	a.assertion()
   143  	if didFail, message := shouldNotBeEmpty(collection); didFail {
   144  		return a.fail(message, userMessageComponents...)
   145  	}
   146  	return true
   147  }
   148  
   149  // Equal asserts that two objects are deeply equal.
   150  func (a *Assertions) Equal(expected interface{}, actual interface{}, userMessageComponents ...interface{}) bool {
   151  	a.assertion()
   152  	if didFail, message := shouldBeEqual(expected, actual); didFail {
   153  		return a.fail(message, userMessageComponents...)
   154  	}
   155  	return true
   156  }
   157  
   158  // ReferenceEqual asserts that two objects are the same reference in memory.
   159  func (a *Assertions) ReferenceEqual(expected interface{}, actual interface{}, userMessageComponents ...interface{}) bool {
   160  	a.assertion()
   161  	if didFail, message := shouldBeReferenceEqual(expected, actual); didFail {
   162  		return a.fail(message, userMessageComponents...)
   163  	}
   164  	return true
   165  }
   166  
   167  // NotEqual asserts that two objects are not deeply equal.
   168  func (a *Assertions) NotEqual(expected interface{}, actual interface{}, userMessageComponents ...interface{}) bool {
   169  	a.assertion()
   170  	if didFail, message := shouldNotBeEqual(expected, actual); didFail {
   171  		return a.fail(message, userMessageComponents...)
   172  	}
   173  	return true
   174  }
   175  
   176  // PanicEqual asserts the panic emitted by an action equals an expected value.
   177  func (a *Assertions) PanicEqual(expected interface{}, action func(), userMessageComponents ...interface{}) bool {
   178  	a.assertion()
   179  	if didFail, message := shouldBePanicEqual(expected, action); didFail {
   180  		return a.fail(message, userMessageComponents...)
   181  	}
   182  	return true
   183  }
   184  
   185  // NotPanic asserts the given action does not panic.
   186  func (a *Assertions) NotPanic(action func(), userMessageComponents ...interface{}) bool {
   187  	a.assertion()
   188  	if didFail, message := shouldNotPanic(action); didFail {
   189  		return a.fail(message, userMessageComponents...)
   190  	}
   191  	return true
   192  }
   193  
   194  // Zero asserts that a value is equal to it's default value.
   195  func (a *Assertions) Zero(value interface{}, userMessageComponents ...interface{}) bool {
   196  	a.assertion()
   197  	if didFail, message := shouldBeZero(value); didFail {
   198  		return a.fail(message, userMessageComponents...)
   199  	}
   200  	return true
   201  }
   202  
   203  // NotZero asserts that a value is not equal to it's default value.
   204  func (a *Assertions) NotZero(value interface{}, userMessageComponents ...interface{}) bool {
   205  	a.assertion()
   206  	if didFail, message := shouldBeNonZero(value); didFail {
   207  		return a.fail(message, userMessageComponents...)
   208  	}
   209  	return true
   210  }
   211  
   212  // True asserts a boolean is true.
   213  func (a *Assertions) True(object bool, userMessageComponents ...interface{}) bool {
   214  	a.assertion()
   215  	if didFail, message := shouldBeTrue(object); didFail {
   216  		return a.fail(message, userMessageComponents...)
   217  	}
   218  	return true
   219  }
   220  
   221  // False asserts a boolean is false.
   222  func (a *Assertions) False(object bool, userMessageComponents ...interface{}) bool {
   223  	a.assertion()
   224  	if didFail, message := shouldBeFalse(object); didFail {
   225  		return a.fail(message, userMessageComponents...)
   226  	}
   227  	return true
   228  }
   229  
   230  // InDelta asserts that two floats are within a delta.
   231  //
   232  // The delta is computed by the absolute of the difference betwee `f0` and `f1`
   233  // and testing if that absolute difference is strictly less than `delta`
   234  // if greater, it will fail the assertion, if delta is equal to or greater than difference
   235  // the assertion will pass.
   236  func (a *Assertions) InDelta(f0, f1, delta float64, userMessageComponents ...interface{}) bool {
   237  	a.assertion()
   238  	if didFail, message := shouldBeInDelta(f0, f1, delta); didFail {
   239  		return a.fail(message, userMessageComponents...)
   240  	}
   241  	return true
   242  }
   243  
   244  // InTimeDelta asserts that times t1 and t2 are within a delta.
   245  func (a *Assertions) InTimeDelta(t1, t2 time.Time, delta time.Duration, userMessageComponents ...interface{}) bool {
   246  	a.assertion()
   247  	if didFail, message := shouldBeInTimeDelta(t1, t2, delta); didFail {
   248  		return a.fail(message, userMessageComponents...)
   249  	}
   250  	return true
   251  }
   252  
   253  // NotInTimeDelta asserts that times t1 and t2 are not within a delta.
   254  func (a *Assertions) NotInTimeDelta(t1, t2 time.Time, delta time.Duration, userMessageComponents ...interface{}) bool {
   255  	a.assertion()
   256  	if didFail, message := shouldNotBeInTimeDelta(t1, t2, delta); didFail {
   257  		return a.fail(message, userMessageComponents...)
   258  	}
   259  	return true
   260  }
   261  
   262  // FileExists asserts that a file exists at a given filepath on disk.
   263  func (a *Assertions) FileExists(filepath string, userMessageComponents ...interface{}) bool {
   264  	a.assertion()
   265  	if didFail, message := shouldFileExist(filepath); didFail {
   266  		return a.fail(message, userMessageComponents...)
   267  	}
   268  	return true
   269  }
   270  
   271  // Contains asserts that a substring is present in a corpus.
   272  func (a *Assertions) Contains(corpus, substring string, userMessageComponents ...interface{}) bool {
   273  	a.assertion()
   274  	if didFail, message := shouldContain(corpus, substring); didFail {
   275  		return a.fail(message, userMessageComponents...)
   276  	}
   277  	return true
   278  }
   279  
   280  // NotContains asserts that a substring is present in a corpus.
   281  func (a *Assertions) NotContains(corpus, substring string, userMessageComponents ...interface{}) bool {
   282  	a.assertion()
   283  	if didFail, message := shouldNotContain(corpus, substring); didFail {
   284  		return a.fail(message, userMessageComponents...)
   285  	}
   286  	return true
   287  }
   288  
   289  // HasPrefix asserts that a corpus has a given prefix.
   290  func (a *Assertions) HasPrefix(corpus, prefix string, userMessageComponents ...interface{}) bool {
   291  	a.assertion()
   292  	if didFail, message := shouldHasPrefix(corpus, prefix); didFail {
   293  		return a.fail(message, userMessageComponents...)
   294  	}
   295  	return true
   296  }
   297  
   298  // NotHasPrefix asserts that a corpus does not have a given prefix.
   299  func (a *Assertions) NotHasPrefix(corpus, prefix string, userMessageComponents ...interface{}) bool {
   300  	a.assertion()
   301  	if didFail, message := shouldNotHasPrefix(corpus, prefix); didFail {
   302  		return a.fail(message, userMessageComponents...)
   303  	}
   304  	return true
   305  }
   306  
   307  // HasSuffix asserts that a corpus has a given suffix.
   308  func (a *Assertions) HasSuffix(corpus, suffix string, userMessageComponents ...interface{}) bool {
   309  	a.assertion()
   310  	if didFail, message := shouldHasSuffix(corpus, suffix); didFail {
   311  		return a.fail(message, userMessageComponents...)
   312  	}
   313  	return true
   314  }
   315  
   316  // NotHasSuffix asserts that a corpus does not have a given suffix.
   317  func (a *Assertions) NotHasSuffix(corpus, suffix string, userMessageComponents ...interface{}) bool {
   318  	a.assertion()
   319  	if didFail, message := shouldNotHasSuffix(corpus, suffix); didFail {
   320  		return a.fail(message, userMessageComponents...)
   321  	}
   322  	return true
   323  }
   324  
   325  // Matches returns if a given value matches a given regexp expression.
   326  func (a *Assertions) Matches(expr string, value interface{}, userMessageComponents ...interface{}) bool {
   327  	a.assertion()
   328  	if didFail, message := shouldMatch(expr, value); didFail {
   329  		return a.fail(message, userMessageComponents...)
   330  	}
   331  	return true
   332  }
   333  
   334  // NotMatches returns if a given value does not match a given regexp expression.
   335  func (a *Assertions) NotMatches(expr string, value interface{}, userMessageComponents ...interface{}) bool {
   336  	a.assertion()
   337  	if didFail, message := shouldNotMatch(expr, value); didFail {
   338  		return a.fail(message, userMessageComponents...)
   339  	}
   340  	return true
   341  }
   342  
   343  // Any applies a predicate.
   344  func (a *Assertions) Any(target interface{}, predicate Predicate, userMessageComponents ...interface{}) bool {
   345  	a.assertion()
   346  	if didFail, message := shouldAny(target, predicate); didFail {
   347  		return a.fail(message, userMessageComponents...)
   348  	}
   349  	return true
   350  }
   351  
   352  // AnyOfInt applies a predicate.
   353  func (a *Assertions) AnyOfInt(target []int, predicate PredicateOfInt, userMessageComponents ...interface{}) bool {
   354  	a.assertion()
   355  	if didFail, message := shouldAnyOfInt(target, predicate); didFail {
   356  		return a.fail(message, userMessageComponents...)
   357  	}
   358  	return true
   359  }
   360  
   361  // AnyOfFloat64 applies a predicate.
   362  func (a *Assertions) AnyOfFloat64(target []float64, predicate PredicateOfFloat, userMessageComponents ...interface{}) bool {
   363  	a.assertion()
   364  	if didFail, message := shouldAnyOfFloat(target, predicate); didFail {
   365  		return a.fail(message, userMessageComponents...)
   366  	}
   367  	return true
   368  }
   369  
   370  // AnyOfString applies a predicate.
   371  func (a *Assertions) AnyOfString(target []string, predicate PredicateOfString, userMessageComponents ...interface{}) bool {
   372  	a.assertion()
   373  	if didFail, message := shouldAnyOfString(target, predicate); didFail {
   374  		return a.fail(message, userMessageComponents...)
   375  	}
   376  	return true
   377  }
   378  
   379  // AnyCount applies a predicate and passes if it fires a given number of times .
   380  func (a *Assertions) AnyCount(target interface{}, times int, predicate Predicate, userMessageComponents ...interface{}) bool {
   381  	a.assertion()
   382  	if didFail, message := shouldAnyCount(target, times, predicate); didFail {
   383  		return a.fail(message, userMessageComponents...)
   384  	}
   385  	return true
   386  }
   387  
   388  // All applies a predicate.
   389  func (a *Assertions) All(target interface{}, predicate Predicate, userMessageComponents ...interface{}) bool {
   390  	a.assertion()
   391  	if didFail, message := shouldAll(target, predicate); didFail {
   392  		return a.fail(message, userMessageComponents...)
   393  	}
   394  	return true
   395  }
   396  
   397  // AllOfInt applies a predicate.
   398  func (a *Assertions) AllOfInt(target []int, predicate PredicateOfInt, userMessageComponents ...interface{}) bool {
   399  	a.assertion()
   400  	if didFail, message := shouldAllOfInt(target, predicate); didFail {
   401  		return a.fail(message, userMessageComponents...)
   402  	}
   403  	return true
   404  }
   405  
   406  // AllOfFloat64 applies a predicate.
   407  func (a *Assertions) AllOfFloat64(target []float64, predicate PredicateOfFloat, userMessageComponents ...interface{}) bool {
   408  	a.assertion()
   409  	if didFail, message := shouldAllOfFloat(target, predicate); didFail {
   410  		return a.fail(message, userMessageComponents...)
   411  	}
   412  	return true
   413  }
   414  
   415  // AllOfString applies a predicate.
   416  func (a *Assertions) AllOfString(target []string, predicate PredicateOfString, userMessageComponents ...interface{}) bool {
   417  	a.assertion()
   418  	if didFail, message := shouldAllOfString(target, predicate); didFail {
   419  		return a.fail(message, userMessageComponents...)
   420  	}
   421  	return true
   422  }
   423  
   424  // None applies a predicate.
   425  func (a *Assertions) None(target interface{}, predicate Predicate, userMessageComponents ...interface{}) bool {
   426  	a.assertion()
   427  	if didFail, message := shouldNone(target, predicate); didFail {
   428  		return a.fail(message, userMessageComponents...)
   429  	}
   430  	return true
   431  }
   432  
   433  // NoneOfInt applies a predicate.
   434  func (a *Assertions) NoneOfInt(target []int, predicate PredicateOfInt, userMessageComponents ...interface{}) bool {
   435  	a.assertion()
   436  	if didFail, message := shouldNoneOfInt(target, predicate); didFail {
   437  		return a.fail(message, userMessageComponents...)
   438  	}
   439  	return true
   440  }
   441  
   442  // NoneOfFloat64 applies a predicate.
   443  func (a *Assertions) NoneOfFloat64(target []float64, predicate PredicateOfFloat, userMessageComponents ...interface{}) bool {
   444  	a.assertion()
   445  	if didFail, message := shouldNoneOfFloat(target, predicate); didFail {
   446  		return a.fail(message, userMessageComponents...)
   447  	}
   448  	return true
   449  }
   450  
   451  // NoneOfString applies a predicate.
   452  func (a *Assertions) NoneOfString(target []string, predicate PredicateOfString, userMessageComponents ...interface{}) bool {
   453  	a.assertion()
   454  	if didFail, message := shouldNoneOfString(target, predicate); didFail {
   455  		return a.fail(message, userMessageComponents...)
   456  	}
   457  	return true
   458  }
   459  
   460  // FailNow forces a test failure (useful for debugging).
   461  func (a *Assertions) FailNow(userMessageComponents ...interface{}) {
   462  	failNow(a.Output, a.T, a.OutputFormat, NewFailure("Fatal Assertion Failed", userMessageComponents...))
   463  }
   464  
   465  // Fail forces a test failure (useful for debugging).
   466  func (a *Assertions) Fail(userMessageComponents ...interface{}) bool {
   467  	fail(a.Output, a.T, a.OutputFormat, NewFailure("Fatal Assertion Failed", userMessageComponents...))
   468  	return true
   469  }
   470  
   471  // --------------------------------------------------------------------------------
   472  // OUTPUT
   473  // --------------------------------------------------------------------------------
   474  
   475  func failNow(w io.Writer, t *testing.T, outputFormat OutputFormat, failure Failure) {
   476  	fail(w, t, outputFormat, failure)
   477  	if t != nil {
   478  		t.FailNow()
   479  	} else {
   480  		panic(failure)
   481  	}
   482  }
   483  
   484  func fail(w io.Writer, t *testing.T, outputFormat OutputFormat, failure Failure) {
   485  	var output string
   486  	switch outputFormat {
   487  	case OutputFormatDefault, OutputFormatText:
   488  		output = fmt.Sprintf("\r%s", getClearString())
   489  		output += failure.Text()
   490  	case OutputFormatJSON:
   491  		output = fmt.Sprintf("\r%s", getLocationString())
   492  		output += failure.JSON()
   493  	default:
   494  		panic(fmt.Errorf("invalid output format: %s", outputFormat))
   495  	}
   496  	if t != nil {
   497  		t.Error(output)
   498  	}
   499  	if w != nil {
   500  		fmt.Fprint(w, output)
   501  	}
   502  }
   503  
   504  func callerInfoStrings(frames []stackFrame) []string {
   505  	output := make([]string, len(frames))
   506  	for index := range frames {
   507  		output[index] = frames[index].String()
   508  	}
   509  	return output
   510  }
   511  
// stackFrame is one entry captured by callerInfo from runtime.Caller.
type stackFrame struct {
	// PC is the program counter reported by runtime.Caller.
	PC       uintptr
	// FileFull is the full file path reported by runtime.Caller.
	FileFull string
	// Dir is the last directory component of FileFull.
	Dir      string
	// File is the base name of FileFull.
	File     string
	// Name is not populated by callerInfo in this file.
	Name     string
	// Line is the source line number.
	Line     int
	// OK is the success flag returned by runtime.Caller.
	OK       bool
}

// String renders the frame as "file:line" using the base file name.
func (sf stackFrame) String() string {
	return fmt.Sprintf("%s:%d", sf.File, sf.Line)
}
   525  
   526  func callerInfo() []stackFrame {
   527  	var name string
   528  	var callers []stackFrame
   529  	for i := 0; ; i++ {
   530  		var frame stackFrame
   531  		frame.PC, frame.FileFull, frame.Line, frame.OK = runtime.Caller(i)
   532  		if !frame.OK {
   533  			break
   534  		}
   535  
   536  		if frame.FileFull == "<autogenerated>" {
   537  			break
   538  		}
   539  
   540  		parts := strings.Split(frame.FileFull, "/")
   541  		frame.Dir = parts[len(parts)-2]
   542  		frame.File = parts[len(parts)-1]
   543  		if frame.Dir != "assert" {
   544  			callers = append(callers, frame)
   545  		}
   546  
   547  		f := runtime.FuncForPC(frame.PC)
   548  		if f == nil {
   549  			break
   550  		}
   551  		name = f.Name()
   552  
   553  		// Drop the package
   554  		segments := strings.Split(name, ".")
   555  		name = segments[len(segments)-1]
   556  		if isTest(name, "Test") ||
   557  			isTest(name, "Benchmark") ||
   558  			isTest(name, "Example") {
   559  			break
   560  		}
   561  	}
   562  
   563  	return callers
   564  }
   565  
   566  func color(input string, colorCode string) string {
   567  	return fmt.Sprintf("\033[%s;01m%s\033[0m", colorCode, input)
   568  }
   569  
   570  func isTest(name, prefix string) bool {
   571  	if !strings.HasPrefix(name, prefix) {
   572  		return false
   573  	}
   574  	if len(name) == len(prefix) { // "Test" is ok
   575  		return true
   576  	}
   577  	rune, _ := utf8.DecodeRuneInString(name[len(prefix):])
   578  	return !unicode.IsLower(rune)
   579  }
   580  
   581  func getClearString() string {
   582  	_, file, line, ok := runtime.Caller(1)
   583  	if !ok {
   584  		return ""
   585  	}
   586  	parts := strings.Split(file, "/")
   587  	file = parts[len(parts)-1]
   588  
   589  	return strings.Repeat(" ", len(fmt.Sprintf("%s:%d:      ", file, line))+2)
   590  }
   591  
   592  func getLocationString() string {
   593  	callers := callerInfo()
   594  	if len(callers) == 0 {
   595  		return ""
   596  	}
   597  	last := callers[len(callers)-1]
   598  	return fmt.Sprintf("%s:%d:      ", last.File, last.Line)
   599  }
   600  
   601  func safeExec(action func()) (err error) {
   602  	defer func() {
   603  		if r := recover(); r != nil {
   604  			err = fmt.Errorf("%v", r)
   605  		}
   606  	}()
   607  	action()
   608  	return
   609  }
   610  
   611  func randomString(length int) string {
   612  	const charset = "abcdefghijklmnopqrstuvwxyz"
   613  	b := make([]byte, length)
   614  	for i := range b {
   615  		b[i] = charset[rand.Intn(len(charset))]
   616  	}
   617  	return string(b)
   618  }
   619  
   620  // --------------------------------------------------------------------------------
   621  // ASSERTION LOGIC
   622  // --------------------------------------------------------------------------------
   623  
   624  func shouldHaveLength(collection interface{}, length int) (bool, string) {
   625  	if l := getLength(collection); l != length {
   626  		message := shouldBeMultipleMessage(length, l, "Collection should have length")
   627  		return true, message
   628  	}
   629  	return false, ""
   630  }
   631  
   632  func shouldNotBeEmpty(collection interface{}) (bool, string) {
   633  	if l := getLength(collection); l == 0 {
   634  		message := "Should not be empty"
   635  		return true, message
   636  	}
   637  	return false, ""
   638  }
   639  
   640  func shouldBeEmpty(collection interface{}) (bool, string) {
   641  	if l := getLength(collection); l != 0 {
   642  		message := shouldBeMessage(collection, "Should be empty")
   643  		return true, message
   644  	}
   645  	return false, ""
   646  }
   647  
   648  func shouldBeEqual(expected, actual interface{}) (bool, string) {
   649  	if !areEqual(expected, actual) {
   650  		return true, equalMessage(expected, actual)
   651  	}
   652  	return false, ""
   653  }
   654  
   655  func shouldBeReferenceEqual(expected, actual interface{}) (bool, string) {
   656  	if !areReferenceEqual(expected, actual) {
   657  		return true, referenceEqualMessage(expected, actual)
   658  	}
   659  	return false, ""
   660  }
   661  
   662  func shouldNotPanic(action func()) (bool, string) {
   663  	var actual interface{}
   664  	var didPanic bool
   665  	func() {
   666  		defer func() {
   667  			actual = recover()
   668  			didPanic = actual != nil
   669  		}()
   670  		action()
   671  	}()
   672  
   673  	if didPanic {
   674  		return true, notPanicMessage(actual)
   675  	}
   676  	return false, ""
   677  }
   678  
   679  func shouldBePanicEqual(expected interface{}, action func()) (bool, string) {
   680  	var actual interface{}
   681  	var didPanic bool
   682  	func() {
   683  		defer func() {
   684  			actual = recover()
   685  			didPanic = actual != nil
   686  		}()
   687  		action()
   688  	}()
   689  
   690  	if !didPanic || (didPanic && !areEqual(expected, actual)) {
   691  		return true, panicEqualMessage(didPanic, expected, actual)
   692  	}
   693  	return false, ""
   694  }
   695  
   696  func shouldNotBeEqual(expected, actual interface{}) (bool, string) {
   697  	if areEqual(expected, actual) {
   698  		return true, notEqualMessage(expected, actual)
   699  	}
   700  	return false, ""
   701  }
   702  
   703  func shouldNotBeNil(object interface{}) (bool, string) {
   704  	if isNil(object) {
   705  		return true, "Should not be nil"
   706  	}
   707  	return false, ""
   708  }
   709  
   710  func shouldBeNil(object interface{}) (bool, string) {
   711  	if !isNil(object) {
   712  		return true, shouldBeMessage(object, "Should be nil")
   713  	}
   714  	return false, ""
   715  }
   716  
   717  func shouldBeTrue(value bool) (bool, string) {
   718  	if !value {
   719  		return true, "Should be true"
   720  	}
   721  	return false, ""
   722  }
   723  
   724  func shouldBeFalse(value bool) (bool, string) {
   725  	if value {
   726  		return true, "Should be false"
   727  	}
   728  	return false, ""
   729  }
   730  
   731  func shouldBeZero(value interface{}) (bool, string) {
   732  	if !isZero(value) {
   733  		return true, shouldBeMessage(value, "Should be zero")
   734  	}
   735  	return false, ""
   736  }
   737  
   738  func shouldBeNonZero(value interface{}) (bool, string) {
   739  	if isZero(value) {
   740  		return true, "Should be non-zero"
   741  	}
   742  	return false, ""
   743  }
   744  
   745  func shouldFileExist(filePath string) (bool, string) {
   746  	_, err := os.Stat(filePath)
   747  	if err != nil {
   748  		pwd, _ := os.Getwd()
   749  		message := fmt.Sprintf("File doesnt exist: %s, `pwd`: %s", filePath, pwd)
   750  		return true, message
   751  	}
   752  	return false, ""
   753  }
   754  
   755  func shouldBeInDelta(from, to, delta float64) (bool, string) {
   756  	diff := math.Abs(from - to)
   757  	if diff > delta {
   758  		message := fmt.Sprintf("Absolute difference of %0.5f and %0.5f should be less than %0.5f", from, to, delta)
   759  		return true, message
   760  	}
   761  	return false, ""
   762  }
   763  
   764  func shouldBeInTimeDelta(from, to time.Time, delta time.Duration) (bool, string) {
   765  	var diff time.Duration
   766  	if from.After(to) {
   767  		diff = from.Sub(to)
   768  	} else {
   769  		diff = to.Sub(from)
   770  	}
   771  	if diff > delta {
   772  		message := fmt.Sprintf("Delta of %s and %s should be less than %v", from.Format(time.RFC3339), to.Format(time.RFC3339), delta)
   773  		return true, message
   774  	}
   775  	return false, ""
   776  }
   777  
   778  func shouldNotBeInTimeDelta(from, to time.Time, delta time.Duration) (bool, string) {
   779  	var diff time.Duration
   780  	if from.After(to) {
   781  		diff = from.Sub(to)
   782  	} else {
   783  		diff = to.Sub(from)
   784  	}
   785  
   786  	if diff <= delta {
   787  		message := fmt.Sprintf("Delta of %s and %s should be greater than %v", from.Format(time.RFC3339), to.Format(time.RFC3339), delta)
   788  		return true, message
   789  	}
   790  	return false, ""
   791  }
   792  
   793  func shouldMatch(pattern string, value interface{}) (bool, string) {
   794  	matched, err := regexp.MatchString(pattern, fmt.Sprint(value))
   795  	if err != nil {
   796  		panic(err)
   797  	}
   798  	if !matched {
   799  		message := fmt.Sprintf("`%v` should match `%s`", value, pattern)
   800  		return true, message
   801  	}
   802  	return false, ""
   803  }
   804  
   805  func shouldNotMatch(pattern string, value interface{}) (bool, string) {
   806  	matched, err := regexp.MatchString(pattern, fmt.Sprint(value))
   807  	if err != nil {
   808  		panic(err)
   809  	}
   810  	if matched {
   811  		message := fmt.Sprintf("`%v` should not match `%s`", value, pattern)
   812  		return true, message
   813  	}
   814  	return false, ""
   815  }
   816  
   817  func shouldContain(corpus, subString string) (bool, string) {
   818  	if !strings.Contains(corpus, subString) {
   819  		message := fmt.Sprintf("`%s` should contain `%s`", corpus, subString)
   820  		return true, message
   821  	}
   822  	return false, ""
   823  }
   824  
   825  func shouldNotContain(corpus, subString string) (bool, string) {
   826  	if strings.Contains(corpus, subString) {
   827  		message := fmt.Sprintf("`%s` should not contain `%s`", corpus, subString)
   828  		return true, message
   829  	}
   830  	return false, ""
   831  }
   832  
   833  func shouldHasPrefix(corpus, prefix string) (bool, string) {
   834  	if !strings.HasPrefix(corpus, prefix) {
   835  		message := fmt.Sprintf("`%s` should have prefix `%s`", corpus, prefix)
   836  		return true, message
   837  	}
   838  	return false, ""
   839  }
   840  
   841  func shouldNotHasPrefix(corpus, prefix string) (bool, string) {
   842  	if strings.HasPrefix(corpus, prefix) {
   843  		message := fmt.Sprintf("`%s` should not have prefix `%s`", corpus, prefix)
   844  		return true, message
   845  	}
   846  	return false, ""
   847  }
   848  
   849  func shouldHasSuffix(corpus, suffix string) (bool, string) {
   850  	if !strings.HasSuffix(corpus, suffix) {
   851  		message := fmt.Sprintf("`%s` should have suffix `%s`", corpus, suffix)
   852  		return true, message
   853  	}
   854  	return false, ""
   855  }
   856  
   857  func shouldNotHasSuffix(corpus, suffix string) (bool, string) {
   858  	if strings.HasSuffix(corpus, suffix) {
   859  		message := fmt.Sprintf("`%s` should not have suffix `%s`", corpus, suffix)
   860  		return true, message
   861  	}
   862  	return false, ""
   863  }
   864  
   865  func shouldAny(target interface{}, predicate Predicate) (bool, string) {
   866  	t := reflect.TypeOf(target)
   867  	for t.Kind() == reflect.Ptr {
   868  		t = t.Elem()
   869  	}
   870  
   871  	v := reflect.ValueOf(target)
   872  	for v.Kind() == reflect.Ptr {
   873  		v = v.Elem()
   874  	}
   875  
   876  	if t.Kind() != reflect.Slice {
   877  		return true, "`target` is not a slice"
   878  	}
   879  
   880  	for x := 0; x < v.Len(); x++ {
   881  		obj := v.Index(x).Interface()
   882  		if predicate(obj) {
   883  			return false, ""
   884  		}
   885  	}
   886  	return true, "Predicate did not fire for any element in target"
   887  }
   888  
   889  func shouldAnyCount(target interface{}, times int, predicate Predicate) (bool, string) {
   890  	t := reflect.TypeOf(target)
   891  	for t.Kind() == reflect.Ptr {
   892  		t = t.Elem()
   893  	}
   894  
   895  	v := reflect.ValueOf(target)
   896  	for v.Kind() == reflect.Ptr {
   897  		v = v.Elem()
   898  	}
   899  
   900  	if t.Kind() != reflect.Slice {
   901  		return true, "`target` is not a slice"
   902  	}
   903  
   904  	var seen int
   905  	for x := 0; x < v.Len(); x++ {
   906  		obj := v.Index(x).Interface()
   907  		if predicate(obj) {
   908  			seen++
   909  		}
   910  	}
   911  	if seen != times {
   912  		return true, shouldBeMultipleMessage(times, seen, "Predicate should fire a given number of times")
   913  	}
   914  	return false, ""
   915  }
   916  
   917  func shouldAnyOfInt(target []int, predicate PredicateOfInt) (bool, string) {
   918  	v := reflect.ValueOf(target)
   919  
   920  	for x := 0; x < v.Len(); x++ {
   921  		obj := v.Index(x).Interface().(int)
   922  		if predicate(obj) {
   923  			return false, ""
   924  		}
   925  	}
   926  	return true, "Predicate did not fire for any element in target"
   927  }
   928  
   929  func shouldAnyOfFloat(target []float64, predicate PredicateOfFloat) (bool, string) {
   930  	v := reflect.ValueOf(target)
   931  
   932  	for x := 0; x < v.Len(); x++ {
   933  		obj := v.Index(x).Interface().(float64)
   934  		if predicate(obj) {
   935  			return false, ""
   936  		}
   937  	}
   938  	return true, "Predicate did not fire for any element in target"
   939  }
   940  
   941  func shouldAnyOfString(target []string, predicate PredicateOfString) (bool, string) {
   942  	v := reflect.ValueOf(target)
   943  
   944  	for x := 0; x < v.Len(); x++ {
   945  		obj := v.Index(x).Interface().(string)
   946  		if predicate(obj) {
   947  			return false, ""
   948  		}
   949  	}
   950  	return true, "Predicate did not fire for any element in target"
   951  }
   952  
   953  func shouldAll(target interface{}, predicate Predicate) (bool, string) {
   954  	t := reflect.TypeOf(target)
   955  	for t.Kind() == reflect.Ptr {
   956  		t = t.Elem()
   957  	}
   958  
   959  	v := reflect.ValueOf(target)
   960  	for v.Kind() == reflect.Ptr {
   961  		v = v.Elem()
   962  	}
   963  
   964  	if t.Kind() != reflect.Slice {
   965  		return true, "`target` is not a slice"
   966  	}
   967  
   968  	for x := 0; x < v.Len(); x++ {
   969  		obj := v.Index(x).Interface()
   970  		if !predicate(obj) {
   971  			return true, fmt.Sprintf("Predicate failed for element in target: %#v", obj)
   972  		}
   973  	}
   974  	return false, ""
   975  }
   976  
   977  func shouldAllOfInt(target []int, predicate PredicateOfInt) (bool, string) {
   978  	v := reflect.ValueOf(target)
   979  
   980  	for x := 0; x < v.Len(); x++ {
   981  		obj := v.Index(x).Interface().(int)
   982  		if !predicate(obj) {
   983  			return true, fmt.Sprintf("Predicate failed for element in target: %#v", obj)
   984  		}
   985  	}
   986  	return false, ""
   987  }
   988  
   989  func shouldAllOfFloat(target []float64, predicate PredicateOfFloat) (bool, string) {
   990  	v := reflect.ValueOf(target)
   991  
   992  	for x := 0; x < v.Len(); x++ {
   993  		obj := v.Index(x).Interface().(float64)
   994  		if !predicate(obj) {
   995  			return true, fmt.Sprintf("Predicate failed for element in target: %#v", obj)
   996  		}
   997  	}
   998  	return false, ""
   999  }
  1000  
  1001  func shouldAllOfString(target []string, predicate PredicateOfString) (bool, string) {
  1002  	v := reflect.ValueOf(target)
  1003  
  1004  	for x := 0; x < v.Len(); x++ {
  1005  		obj := v.Index(x).Interface().(string)
  1006  		if !predicate(obj) {
  1007  			return true, fmt.Sprintf("Predicate failed for element in target: %#v", obj)
  1008  		}
  1009  	}
  1010  	return false, ""
  1011  }
  1012  
  1013  func shouldNone(target interface{}, predicate Predicate) (bool, string) {
  1014  	t := reflect.TypeOf(target)
  1015  	for t.Kind() == reflect.Ptr {
  1016  		t = t.Elem()
  1017  	}
  1018  
  1019  	v := reflect.ValueOf(target)
  1020  	for v.Kind() == reflect.Ptr {
  1021  		v = v.Elem()
  1022  	}
  1023  
  1024  	if t.Kind() != reflect.Slice {
  1025  		return true, "`target` is not a slice"
  1026  	}
  1027  
  1028  	for x := 0; x < v.Len(); x++ {
  1029  		obj := v.Index(x).Interface()
  1030  		if predicate(obj) {
  1031  			return true, fmt.Sprintf("Predicate passed for element in target: %#v", obj)
  1032  		}
  1033  	}
  1034  	return false, ""
  1035  }
  1036  
  1037  func shouldNoneOfInt(target []int, predicate PredicateOfInt) (bool, string) {
  1038  	v := reflect.ValueOf(target)
  1039  
  1040  	for x := 0; x < v.Len(); x++ {
  1041  		obj := v.Index(x).Interface().(int)
  1042  		if predicate(obj) {
  1043  			return true, fmt.Sprintf("Predicate passed for element in target: %#v", obj)
  1044  		}
  1045  	}
  1046  	return false, ""
  1047  }
  1048  
  1049  func shouldNoneOfFloat(target []float64, predicate PredicateOfFloat) (bool, string) {
  1050  	v := reflect.ValueOf(target)
  1051  
  1052  	for x := 0; x < v.Len(); x++ {
  1053  		obj := v.Index(x).Interface().(float64)
  1054  		if predicate(obj) {
  1055  			return true, fmt.Sprintf("Predicate passed for element in target: %#v", obj)
  1056  		}
  1057  	}
  1058  	return false, ""
  1059  }
  1060  
  1061  func shouldNoneOfString(target []string, predicate PredicateOfString) (bool, string) {
  1062  	v := reflect.ValueOf(target)
  1063  
  1064  	for x := 0; x < v.Len(); x++ {
  1065  		obj := v.Index(x).Interface().(string)
  1066  		if predicate(obj) {
  1067  			return true, fmt.Sprintf("Predicate passed for element in target: %#v", obj)
  1068  		}
  1069  	}
  1070  	return false, ""
  1071  }
  1072  
  1073  // --------------------------------------------------------------------------------
  1074  // UTILITY
  1075  // --------------------------------------------------------------------------------
  1076  
  1077  func shouldBeMultipleMessage(expected, actual interface{}, message string) string {
  1078  	expectedLabel := color("Expected", WHITE)
  1079  	actualLabel := color("Actual", WHITE)
  1080  
  1081  	return fmt.Sprintf(`%s
  1082  	%s: 	%#v
  1083  	%s: 	%#v`, message, expectedLabel, expected, actualLabel, actual)
  1084  }
  1085  
  1086  func shouldBeMessage(object interface{}, message string) string {
  1087  	actualLabel := color("Actual", WHITE)
  1088  	if err, ok := object.(error); ok {
  1089  		return fmt.Sprintf(`%s
  1090  	%s: 	%+v`, message, actualLabel, err)
  1091  	}
  1092  	return fmt.Sprintf(`%s
  1093  	%s: 	%#v`, message, actualLabel, object)
  1094  }
  1095  
  1096  func notEqualMessage(expected, actual interface{}) string {
  1097  	return shouldBeMultipleMessage(expected, actual, "Objects should not be equal")
  1098  }
  1099  
  1100  func equalMessage(expected, actual interface{}) string {
  1101  	return shouldBeMultipleMessage(expected, actual, "Objects should be equal")
  1102  }
  1103  
  1104  func referenceEqualMessage(expected, actual interface{}) string {
  1105  	return shouldBeMultipleMessage(expected, actual, "References should be equal")
  1106  }
  1107  
  1108  func panicEqualMessage(didPanic bool, expected, actual interface{}) string {
  1109  	if !didPanic {
  1110  		return "Should have produced a panic"
  1111  	}
  1112  	return shouldBeMultipleMessage(expected, actual, "Panic from action should equal")
  1113  }
  1114  
  1115  func notPanicMessage(actual interface{}) string {
  1116  	return shouldBeMessage(actual, "Should not have panicked")
  1117  }
  1118  
  1119  func getLength(object interface{}) int {
  1120  	if object == nil {
  1121  		return 0
  1122  	} else if object == "" {
  1123  		return 0
  1124  	}
  1125  
  1126  	objValue := reflect.ValueOf(object)
  1127  
  1128  	switch objValue.Kind() {
  1129  	case reflect.Map:
  1130  		fallthrough
  1131  	case reflect.Slice, reflect.Chan, reflect.String:
  1132  		{
  1133  			return objValue.Len()
  1134  		}
  1135  	}
  1136  	return 0
  1137  }
  1138  
  1139  func isNil(object interface{}) bool {
  1140  	if object == nil {
  1141  		return true
  1142  	}
  1143  
  1144  	value := reflect.ValueOf(object)
  1145  	kind := value.Kind()
  1146  	if kind >= reflect.Chan && kind <= reflect.Slice && value.IsNil() {
  1147  		return true
  1148  	}
  1149  	return false
  1150  }
  1151  
  1152  func isZero(value interface{}) bool {
  1153  	return areEqual(0, value)
  1154  }
  1155  
  1156  func areReferenceEqual(expected, actual interface{}) bool {
  1157  	if expected == nil && actual == nil {
  1158  		return true
  1159  	}
  1160  	if (expected == nil && actual != nil) || (expected != nil && actual == nil) {
  1161  		return false
  1162  	}
  1163  
  1164  	return expected == actual
  1165  }
  1166  
  1167  func areEqual(expected, actual interface{}) bool {
  1168  	if expected == nil && actual == nil {
  1169  		return true
  1170  	}
  1171  	if (expected == nil && actual != nil) || (expected != nil && actual == nil) {
  1172  		return false
  1173  	}
  1174  
  1175  	actualType := reflect.TypeOf(actual)
  1176  	if actualType == nil {
  1177  		return false
  1178  	}
  1179  	expectedValue := reflect.ValueOf(expected)
  1180  	if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
  1181  		return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
  1182  	}
  1183  
  1184  	return reflect.DeepEqual(expected, actual)
  1185  }