github.com/blend/go-sdk@v1.20240719.1/assert/assert.go (about)

     1  /*
     2  
     3  Copyright (c) 2024 - Present. Blend Labs, Inc. All rights reserved
     4  Use of this source code is governed by a MIT license that can be found in the LICENSE file.
     5  
     6  */
     7  
     8  package assert
     9  
    10  import (
    11  	"context"
    12  	"fmt"
    13  	"io"
    14  	"math"
    15  	"math/rand"
    16  	"os"
    17  	"reflect"
    18  	"regexp"
    19  	"runtime"
    20  	"strings"
    21  	"sync/atomic"
    22  	"testing"
    23  	"time"
    24  	"unicode"
    25  	"unicode/utf8"
    26  )
    27  
    28  // Empty returns an empty assertions handler; useful when you want to apply assertions w/o hooking into the testing framework.
    29  func Empty(opts ...Option) *Assertions {
    30  	a := Assertions{
    31  		OutputFormat: OutputFormatFromEnv(),
    32  		Context:      WithContextID(context.Background(), randomString(8)),
    33  	}
    34  	for _, opt := range opts {
    35  		opt(&a)
    36  	}
    37  	return &a
    38  }
    39  
    40  // New returns a new instance of `Assertions`.
    41  func New(t *testing.T, opts ...Option) *Assertions {
    42  	a := Assertions{
    43  		T:            t,
    44  		OutputFormat: OutputFormatFromEnv(),
    45  		Context:      WithContextID(context.Background(), randomString(8)),
    46  	}
    47  	if t != nil {
    48  		a.Context = WithTestName(a.Context, t.Name())
    49  	}
    50  	for _, opt := range opts {
    51  		opt(&a)
    52  	}
    53  	return &a
    54  }
    55  
// Assertions is the main entry point for using the assertions library.
type Assertions struct {
	// Output, when non-nil, additionally receives the rendered failure text.
	Output io.Writer
	// OutputFormat selects how failures are rendered (text, JSON or unit test form).
	OutputFormat OutputFormat
	// T is the test that failures are reported to; it may be nil.
	T *testing.T
	// Context is the context returned by Background.
	Context context.Context
	// Optional marks the handler as non-fatal: failures are reported but do not abort the test.
	Optional bool
	// Count is the number of assertions made; it is incremented atomically.
	Count int32
}
    65  
    66  // Background returns the assertions context.
    67  func (a *Assertions) Background() context.Context {
    68  	return a.Context
    69  }
    70  
    71  // assertion represents the actions to take for *each* assertion.
    72  // it is used internally for stats tracking.
    73  func (a *Assertions) assertion() {
    74  	atomic.AddInt32(&a.Count, 1)
    75  }
    76  
    77  // NonFatal transitions the assertion into a `NonFatal` assertion; that is, one that will not cause the test to abort if it fails.
    78  // NonFatal assertions are useful when you want to check many properties during a test, but only on an informational basis.
    79  // They will typically return a bool to indicate if the assertion succeeded, or if you should consider the overall
    80  // test to still be a success.
    81  func (a *Assertions) NonFatal() *Assertions {
    82  	return &Assertions{
    83  		T:            a.T,
    84  		Output:       a.Output,
    85  		OutputFormat: a.OutputFormat,
    86  		Optional:     true,
    87  	}
    88  }
    89  
    90  // NotImplemented will just error.
    91  func (a *Assertions) NotImplemented(userMessageComponents ...interface{}) {
    92  	fail(a.Output, a.T, a.OutputFormat, NewFailure("the current test is not implemented", userMessageComponents...))
    93  }
    94  
    95  func (a *Assertions) fail(message string, userMessageComponents ...interface{}) bool {
    96  	if a.Optional {
    97  		fail(a.Output, a.T, a.OutputFormat, NewFailure(message, userMessageComponents...))
    98  		return false
    99  	}
   100  	failNow(a.Output, a.T, a.OutputFormat, NewFailure(message, userMessageComponents...))
   101  	return false
   102  }
   103  
   104  // NotNil asserts that a reference is not nil.
   105  func (a *Assertions) NotNil(object interface{}, userMessageComponents ...interface{}) bool {
   106  	a.assertion()
   107  	if didFail, message := shouldNotBeNil(object); didFail {
   108  		return a.fail(message, userMessageComponents...)
   109  	}
   110  	return true
   111  }
   112  
   113  // Nil asserts that a reference is nil.
   114  func (a *Assertions) Nil(object interface{}, userMessageComponents ...interface{}) bool {
   115  	a.assertion()
   116  	if didFail, message := shouldBeNil(object); didFail {
   117  		return a.fail(message, userMessageComponents...)
   118  	}
   119  	return true
   120  }
   121  
   122  // Len asserts that a collection has a given length.
   123  func (a *Assertions) Len(collection interface{}, length int, userMessageComponents ...interface{}) bool {
   124  	a.assertion()
   125  	if didFail, message := shouldHaveLength(collection, length); didFail {
   126  		return a.fail(message, userMessageComponents...)
   127  	}
   128  	return true
   129  }
   130  
   131  // Empty asserts that a collection is empty.
   132  func (a *Assertions) Empty(collection interface{}, userMessageComponents ...interface{}) bool {
   133  	a.assertion()
   134  	if didFail, message := shouldBeEmpty(collection); didFail {
   135  		return a.fail(message, userMessageComponents...)
   136  	}
   137  	return true
   138  }
   139  
   140  // EmptyBufferedChannel asserts that a channel is buffered (has a non-zero capacity),
   141  // and that it is empty (has length zero). This is useful when using channels to mock API
   142  // interface responses; a `len(ch) == 0` check is necessary but not necessarily sufficient
   143  // to ensure consumption of the mock channel's contents because the assertion will give a
   144  // false negative when the channel is unbuffered.
   145  func (a *Assertions) EmptyBufferedChannel(ch any, userMessageComponents ...any) bool {
   146  	a.assertion()
   147  	if didFail, message := chanShouldHaveNonzeroCapacity(ch); didFail {
   148  		return a.fail(message, userMessageComponents...)
   149  	}
   150  	if didFail, message := shouldBeEmpty(ch); didFail {
   151  		return a.fail(message, userMessageComponents...)
   152  	}
   153  	return true
   154  }
   155  
   156  // NotEmpty asserts that a collection is not empty.
   157  func (a *Assertions) NotEmpty(collection interface{}, userMessageComponents ...interface{}) bool {
   158  	a.assertion()
   159  	if didFail, message := shouldNotBeEmpty(collection); didFail {
   160  		return a.fail(message, userMessageComponents...)
   161  	}
   162  	return true
   163  }
   164  
   165  // Equal asserts that two objects are deeply equal.
   166  func (a *Assertions) Equal(expected interface{}, actual interface{}, userMessageComponents ...interface{}) bool {
   167  	a.assertion()
   168  	if didFail, message := shouldBeEqual(expected, actual); didFail {
   169  		return a.fail(message, userMessageComponents...)
   170  	}
   171  	return true
   172  }
   173  
   174  // ReferenceEqual asserts that two objects are the same reference in memory.
   175  func (a *Assertions) ReferenceEqual(expected interface{}, actual interface{}, userMessageComponents ...interface{}) bool {
   176  	a.assertion()
   177  	if didFail, message := shouldBeReferenceEqual(expected, actual); didFail {
   178  		return a.fail(message, userMessageComponents...)
   179  	}
   180  	return true
   181  }
   182  
   183  // NotEqual asserts that two objects are not deeply equal.
   184  func (a *Assertions) NotEqual(expected interface{}, actual interface{}, userMessageComponents ...interface{}) bool {
   185  	a.assertion()
   186  	if didFail, message := shouldNotBeEqual(expected, actual); didFail {
   187  		return a.fail(message, userMessageComponents...)
   188  	}
   189  	return true
   190  }
   191  
   192  // PanicEqual asserts the panic emitted by an action equals an expected value.
   193  func (a *Assertions) PanicEqual(expected interface{}, action func(), userMessageComponents ...interface{}) bool {
   194  	a.assertion()
   195  	if didFail, message := shouldBePanicEqual(expected, action); didFail {
   196  		return a.fail(message, userMessageComponents...)
   197  	}
   198  	return true
   199  }
   200  
   201  // NotPanic asserts the given action does not panic.
   202  func (a *Assertions) NotPanic(action func(), userMessageComponents ...interface{}) bool {
   203  	a.assertion()
   204  	if didFail, message := shouldNotPanic(action); didFail {
   205  		return a.fail(message, userMessageComponents...)
   206  	}
   207  	return true
   208  }
   209  
   210  // Zero asserts that a value is equal to it's default value.
   211  func (a *Assertions) Zero(value interface{}, userMessageComponents ...interface{}) bool {
   212  	a.assertion()
   213  	if didFail, message := shouldBeZero(value); didFail {
   214  		return a.fail(message, userMessageComponents...)
   215  	}
   216  	return true
   217  }
   218  
   219  // NotZero asserts that a value is not equal to it's default value.
   220  func (a *Assertions) NotZero(value interface{}, userMessageComponents ...interface{}) bool {
   221  	a.assertion()
   222  	if didFail, message := shouldBeNonZero(value); didFail {
   223  		return a.fail(message, userMessageComponents...)
   224  	}
   225  	return true
   226  }
   227  
   228  // True asserts a boolean is true.
   229  func (a *Assertions) True(object bool, userMessageComponents ...interface{}) bool {
   230  	a.assertion()
   231  	if didFail, message := shouldBeTrue(object); didFail {
   232  		return a.fail(message, userMessageComponents...)
   233  	}
   234  	return true
   235  }
   236  
   237  // False asserts a boolean is false.
   238  func (a *Assertions) False(object bool, userMessageComponents ...interface{}) bool {
   239  	a.assertion()
   240  	if didFail, message := shouldBeFalse(object); didFail {
   241  		return a.fail(message, userMessageComponents...)
   242  	}
   243  	return true
   244  }
   245  
   246  // InDelta asserts that two floats are within a delta.
   247  //
   248  // The delta is computed by the absolute of the difference betwee `f0` and `f1`
   249  // and testing if that absolute difference is strictly less than `delta`
   250  // if greater, it will fail the assertion, if delta is equal to or greater than difference
   251  // the assertion will pass.
   252  func (a *Assertions) InDelta(f0, f1, delta float64, userMessageComponents ...interface{}) bool {
   253  	a.assertion()
   254  	if didFail, message := shouldBeInDelta(f0, f1, delta); didFail {
   255  		return a.fail(message, userMessageComponents...)
   256  	}
   257  	return true
   258  }
   259  
   260  // InTimeDelta asserts that times t1 and t2 are within a delta.
   261  func (a *Assertions) InTimeDelta(t1, t2 time.Time, delta time.Duration, userMessageComponents ...interface{}) bool {
   262  	a.assertion()
   263  	if didFail, message := shouldBeInTimeDelta(t1, t2, delta); didFail {
   264  		return a.fail(message, userMessageComponents...)
   265  	}
   266  	return true
   267  }
   268  
   269  // NotInTimeDelta asserts that times t1 and t2 are not within a delta.
   270  func (a *Assertions) NotInTimeDelta(t1, t2 time.Time, delta time.Duration, userMessageComponents ...interface{}) bool {
   271  	a.assertion()
   272  	if didFail, message := shouldNotBeInTimeDelta(t1, t2, delta); didFail {
   273  		return a.fail(message, userMessageComponents...)
   274  	}
   275  	return true
   276  }
   277  
   278  // FileExists asserts that a file exists at a given filepath on disk.
   279  func (a *Assertions) FileExists(filepath string, userMessageComponents ...interface{}) bool {
   280  	a.assertion()
   281  	if didFail, message := shouldFileExist(filepath); didFail {
   282  		return a.fail(message, userMessageComponents...)
   283  	}
   284  	return true
   285  }
   286  
   287  // Contains asserts that a substring is present in a corpus.
   288  func (a *Assertions) Contains(corpus, substring string, userMessageComponents ...interface{}) bool {
   289  	a.assertion()
   290  	if didFail, message := shouldContain(corpus, substring); didFail {
   291  		return a.fail(message, userMessageComponents...)
   292  	}
   293  	return true
   294  }
   295  
   296  // NotContains asserts that a substring is present in a corpus.
   297  func (a *Assertions) NotContains(corpus, substring string, userMessageComponents ...interface{}) bool {
   298  	a.assertion()
   299  	if didFail, message := shouldNotContain(corpus, substring); didFail {
   300  		return a.fail(message, userMessageComponents...)
   301  	}
   302  	return true
   303  }
   304  
   305  // HasPrefix asserts that a corpus has a given prefix.
   306  func (a *Assertions) HasPrefix(corpus, prefix string, userMessageComponents ...interface{}) bool {
   307  	a.assertion()
   308  	if didFail, message := shouldHasPrefix(corpus, prefix); didFail {
   309  		return a.fail(message, userMessageComponents...)
   310  	}
   311  	return true
   312  }
   313  
   314  // NotHasPrefix asserts that a corpus does not have a given prefix.
   315  func (a *Assertions) NotHasPrefix(corpus, prefix string, userMessageComponents ...interface{}) bool {
   316  	a.assertion()
   317  	if didFail, message := shouldNotHasPrefix(corpus, prefix); didFail {
   318  		return a.fail(message, userMessageComponents...)
   319  	}
   320  	return true
   321  }
   322  
   323  // HasSuffix asserts that a corpus has a given suffix.
   324  func (a *Assertions) HasSuffix(corpus, suffix string, userMessageComponents ...interface{}) bool {
   325  	a.assertion()
   326  	if didFail, message := shouldHasSuffix(corpus, suffix); didFail {
   327  		return a.fail(message, userMessageComponents...)
   328  	}
   329  	return true
   330  }
   331  
   332  // NotHasSuffix asserts that a corpus does not have a given suffix.
   333  func (a *Assertions) NotHasSuffix(corpus, suffix string, userMessageComponents ...interface{}) bool {
   334  	a.assertion()
   335  	if didFail, message := shouldNotHasSuffix(corpus, suffix); didFail {
   336  		return a.fail(message, userMessageComponents...)
   337  	}
   338  	return true
   339  }
   340  
   341  // Matches returns if a given value matches a given regexp expression.
   342  func (a *Assertions) Matches(expr string, value interface{}, userMessageComponents ...interface{}) bool {
   343  	a.assertion()
   344  	if didFail, message := shouldMatch(expr, value); didFail {
   345  		return a.fail(message, userMessageComponents...)
   346  	}
   347  	return true
   348  }
   349  
   350  // NotMatches returns if a given value does not match a given regexp expression.
   351  func (a *Assertions) NotMatches(expr string, value interface{}, userMessageComponents ...interface{}) bool {
   352  	a.assertion()
   353  	if didFail, message := shouldNotMatch(expr, value); didFail {
   354  		return a.fail(message, userMessageComponents...)
   355  	}
   356  	return true
   357  }
   358  
   359  // Any applies a predicate.
   360  func (a *Assertions) Any(target interface{}, predicate Predicate, userMessageComponents ...interface{}) bool {
   361  	a.assertion()
   362  	if didFail, message := shouldAny(target, predicate); didFail {
   363  		return a.fail(message, userMessageComponents...)
   364  	}
   365  	return true
   366  }
   367  
   368  // AnyOfInt applies a predicate.
   369  func (a *Assertions) AnyOfInt(target []int, predicate PredicateOfInt, userMessageComponents ...interface{}) bool {
   370  	a.assertion()
   371  	if didFail, message := shouldAnyOfInt(target, predicate); didFail {
   372  		return a.fail(message, userMessageComponents...)
   373  	}
   374  	return true
   375  }
   376  
   377  // AnyOfFloat64 applies a predicate.
   378  func (a *Assertions) AnyOfFloat64(target []float64, predicate PredicateOfFloat, userMessageComponents ...interface{}) bool {
   379  	a.assertion()
   380  	if didFail, message := shouldAnyOfFloat(target, predicate); didFail {
   381  		return a.fail(message, userMessageComponents...)
   382  	}
   383  	return true
   384  }
   385  
   386  // AnyOfString applies a predicate.
   387  func (a *Assertions) AnyOfString(target []string, predicate PredicateOfString, userMessageComponents ...interface{}) bool {
   388  	a.assertion()
   389  	if didFail, message := shouldAnyOfString(target, predicate); didFail {
   390  		return a.fail(message, userMessageComponents...)
   391  	}
   392  	return true
   393  }
   394  
   395  // AnyCount applies a predicate and passes if it fires a given number of times .
   396  func (a *Assertions) AnyCount(target interface{}, times int, predicate Predicate, userMessageComponents ...interface{}) bool {
   397  	a.assertion()
   398  	if didFail, message := shouldAnyCount(target, times, predicate); didFail {
   399  		return a.fail(message, userMessageComponents...)
   400  	}
   401  	return true
   402  }
   403  
   404  // All applies a predicate.
   405  func (a *Assertions) All(target interface{}, predicate Predicate, userMessageComponents ...interface{}) bool {
   406  	a.assertion()
   407  	if didFail, message := shouldAll(target, predicate); didFail {
   408  		return a.fail(message, userMessageComponents...)
   409  	}
   410  	return true
   411  }
   412  
   413  // AllOfInt applies a predicate.
   414  func (a *Assertions) AllOfInt(target []int, predicate PredicateOfInt, userMessageComponents ...interface{}) bool {
   415  	a.assertion()
   416  	if didFail, message := shouldAllOfInt(target, predicate); didFail {
   417  		return a.fail(message, userMessageComponents...)
   418  	}
   419  	return true
   420  }
   421  
   422  // AllOfFloat64 applies a predicate.
   423  func (a *Assertions) AllOfFloat64(target []float64, predicate PredicateOfFloat, userMessageComponents ...interface{}) bool {
   424  	a.assertion()
   425  	if didFail, message := shouldAllOfFloat(target, predicate); didFail {
   426  		return a.fail(message, userMessageComponents...)
   427  	}
   428  	return true
   429  }
   430  
   431  // AllOfString applies a predicate.
   432  func (a *Assertions) AllOfString(target []string, predicate PredicateOfString, userMessageComponents ...interface{}) bool {
   433  	a.assertion()
   434  	if didFail, message := shouldAllOfString(target, predicate); didFail {
   435  		return a.fail(message, userMessageComponents...)
   436  	}
   437  	return true
   438  }
   439  
   440  // None applies a predicate.
   441  func (a *Assertions) None(target interface{}, predicate Predicate, userMessageComponents ...interface{}) bool {
   442  	a.assertion()
   443  	if didFail, message := shouldNone(target, predicate); didFail {
   444  		return a.fail(message, userMessageComponents...)
   445  	}
   446  	return true
   447  }
   448  
   449  // NoneOfInt applies a predicate.
   450  func (a *Assertions) NoneOfInt(target []int, predicate PredicateOfInt, userMessageComponents ...interface{}) bool {
   451  	a.assertion()
   452  	if didFail, message := shouldNoneOfInt(target, predicate); didFail {
   453  		return a.fail(message, userMessageComponents...)
   454  	}
   455  	return true
   456  }
   457  
   458  // NoneOfFloat64 applies a predicate.
   459  func (a *Assertions) NoneOfFloat64(target []float64, predicate PredicateOfFloat, userMessageComponents ...interface{}) bool {
   460  	a.assertion()
   461  	if didFail, message := shouldNoneOfFloat(target, predicate); didFail {
   462  		return a.fail(message, userMessageComponents...)
   463  	}
   464  	return true
   465  }
   466  
   467  // NoneOfString applies a predicate.
   468  func (a *Assertions) NoneOfString(target []string, predicate PredicateOfString, userMessageComponents ...interface{}) bool {
   469  	a.assertion()
   470  	if didFail, message := shouldNoneOfString(target, predicate); didFail {
   471  		return a.fail(message, userMessageComponents...)
   472  	}
   473  	return true
   474  }
   475  
   476  // FailNow forces a test failure (useful for debugging).
   477  func (a *Assertions) FailNow(userMessageComponents ...interface{}) {
   478  	failNow(a.Output, a.T, a.OutputFormat, NewFailure("Fatal Assertion Failed", userMessageComponents...))
   479  }
   480  
   481  // Fail forces a test failure (useful for debugging).
   482  func (a *Assertions) Fail(userMessageComponents ...interface{}) bool {
   483  	fail(a.Output, a.T, a.OutputFormat, NewFailure("Fatal Assertion Failed", userMessageComponents...))
   484  	return true
   485  }
   486  
   487  // --------------------------------------------------------------------------------
   488  // OUTPUT
   489  // --------------------------------------------------------------------------------
   490  
   491  func failNow(w io.Writer, t *testing.T, outputFormat OutputFormat, failure Failure) {
   492  	fail(w, t, outputFormat, failure)
   493  	if t != nil {
   494  		t.FailNow()
   495  	} else {
   496  		panic(failure)
   497  	}
   498  }
   499  
   500  func fail(w io.Writer, t *testing.T, outputFormat OutputFormat, failure Failure) {
   501  	var output string
   502  	switch outputFormat {
   503  	case OutputFormatDefault, OutputFormatText:
   504  		output = fmt.Sprintf("\r%s", getClearString())
   505  		output += failure.Text()
   506  	case OutputFormatJSON:
   507  		output = fmt.Sprintf("\r%s", getLocationString())
   508  		output += failure.JSON()
   509  	case OutputFormatUnitTest:
   510  		output = failure.TestString()
   511  	default:
   512  		panic(fmt.Errorf("invalid output format: %s", outputFormat))
   513  	}
   514  	if t != nil {
   515  		t.Error(output)
   516  	}
   517  	if w != nil {
   518  		fmt.Fprint(w, output)
   519  	}
   520  }
   521  
   522  func callerInfoStrings(frames []stackFrame) []string {
   523  	output := make([]string, len(frames))
   524  	for index := range frames {
   525  		output[index] = frames[index].String()
   526  	}
   527  	return output
   528  }
   529  
// stackFrame describes a single call stack entry as captured by callerInfo.
type stackFrame struct {
	// PC is the program counter reported by runtime.Caller.
	PC uintptr
	// FileFull is the full file path reported by runtime.Caller.
	FileFull string
	// Dir is the last directory segment of FileFull.
	Dir string
	// File is the final path segment (the file name) of FileFull.
	File string
	// Name is not populated by the code visible here — TODO confirm whether it is used elsewhere.
	Name string
	// Line is the line number reported by runtime.Caller.
	Line int
	// OK is the success flag reported by runtime.Caller.
	OK bool
}
   539  
   540  func (sf stackFrame) String() string {
   541  	return fmt.Sprintf("%s:%d", sf.File, sf.Line)
   542  }
   543  
   544  func callerInfo() []stackFrame {
   545  	var name string
   546  	var callers []stackFrame
   547  	for i := 0; ; i++ {
   548  		var frame stackFrame
   549  		frame.PC, frame.FileFull, frame.Line, frame.OK = runtime.Caller(i)
   550  		if !frame.OK {
   551  			break
   552  		}
   553  
   554  		if frame.FileFull == "<autogenerated>" {
   555  			break
   556  		}
   557  
   558  		parts := strings.Split(frame.FileFull, "/")
   559  		frame.Dir = parts[len(parts)-2]
   560  		frame.File = parts[len(parts)-1]
   561  		if frame.Dir != "assert" {
   562  			callers = append(callers, frame)
   563  		}
   564  
   565  		f := runtime.FuncForPC(frame.PC)
   566  		if f == nil {
   567  			break
   568  		}
   569  		name = f.Name()
   570  
   571  		// Drop the package
   572  		segments := strings.Split(name, ".")
   573  		name = segments[len(segments)-1]
   574  		if isTest(name, "Test") ||
   575  			isTest(name, "Benchmark") ||
   576  			isTest(name, "Example") {
   577  			break
   578  		}
   579  	}
   580  
   581  	return callers
   582  }
   583  
   584  func color(input string, colorCode string) string {
   585  	return fmt.Sprintf("\033[%s;01m%s\033[0m", colorCode, input)
   586  }
   587  
   588  func isTest(name, prefix string) bool {
   589  	if !strings.HasPrefix(name, prefix) {
   590  		return false
   591  	}
   592  	if len(name) == len(prefix) { // "Test" is ok
   593  		return true
   594  	}
   595  	rune, _ := utf8.DecodeRuneInString(name[len(prefix):])
   596  	return !unicode.IsLower(rune)
   597  }
   598  
   599  func getClearString() string {
   600  	_, file, line, ok := runtime.Caller(1)
   601  	if !ok {
   602  		return ""
   603  	}
   604  	parts := strings.Split(file, "/")
   605  	file = parts[len(parts)-1]
   606  
   607  	return strings.Repeat(" ", len(fmt.Sprintf("%s:%d:      ", file, line))+2)
   608  }
   609  
   610  func getLocationString() string {
   611  	callers := callerInfo()
   612  	if len(callers) == 0 {
   613  		return ""
   614  	}
   615  	last := callers[len(callers)-1]
   616  	return fmt.Sprintf("%s:%d:      ", last.File, last.Line)
   617  }
   618  
   619  func safeExec(action func()) (err error) {
   620  	defer func() {
   621  		if r := recover(); r != nil {
   622  			err = fmt.Errorf("%v", r)
   623  		}
   624  	}()
   625  	action()
   626  	return
   627  }
   628  
   629  func randomString(length int) string {
   630  	const charset = "abcdefghijklmnopqrstuvwxyz"
   631  	b := make([]byte, length)
   632  	for i := range b {
   633  		b[i] = charset[rand.Intn(len(charset))]
   634  	}
   635  	return string(b)
   636  }
   637  
   638  // --------------------------------------------------------------------------------
   639  // ASSERTION LOGIC
   640  // --------------------------------------------------------------------------------
   641  
   642  func chanShouldHaveNonzeroCapacity(ch interface{}) (bool, string) {
   643  	v := reflect.ValueOf(ch)
   644  	if v.Kind() != reflect.Chan {
   645  		message := "Should be a channel"
   646  		return true, message
   647  	}
   648  	if v.Cap() == 0 {
   649  		message := "Should not have capacity 0"
   650  		return true, message
   651  	}
   652  	return false, ""
   653  }
   654  
   655  func shouldHaveLength(collection interface{}, length int) (bool, string) {
   656  	if l := getLength(collection); l != length {
   657  		message := shouldBeMultipleMessage(length, l, "Collection should have length")
   658  		return true, message
   659  	}
   660  	return false, ""
   661  }
   662  
   663  func shouldNotBeEmpty(collection interface{}) (bool, string) {
   664  	if l := getLength(collection); l == 0 {
   665  		message := "Should not be empty"
   666  		return true, message
   667  	}
   668  	return false, ""
   669  }
   670  
   671  func shouldBeEmpty(collection interface{}) (bool, string) {
   672  	if l := getLength(collection); l != 0 {
   673  		message := shouldBeMessage(collection, "Should be empty")
   674  		return true, message
   675  	}
   676  	return false, ""
   677  }
   678  
   679  func shouldBeEqual(expected, actual interface{}) (bool, string) {
   680  	if !areEqual(expected, actual) {
   681  		return true, equalMessage(expected, actual)
   682  	}
   683  	return false, ""
   684  }
   685  
   686  func shouldBeReferenceEqual(expected, actual interface{}) (bool, string) {
   687  	if !areReferenceEqual(expected, actual) {
   688  		return true, referenceEqualMessage(expected, actual)
   689  	}
   690  	return false, ""
   691  }
   692  
   693  func shouldNotPanic(action func()) (bool, string) {
   694  	var actual interface{}
   695  	var didPanic bool
   696  	func() {
   697  		defer func() {
   698  			actual = recover()
   699  			didPanic = actual != nil
   700  		}()
   701  		action()
   702  	}()
   703  
   704  	if didPanic {
   705  		return true, notPanicMessage(actual)
   706  	}
   707  	return false, ""
   708  }
   709  
   710  func shouldBePanicEqual(expected interface{}, action func()) (bool, string) {
   711  	var actual interface{}
   712  	var didPanic bool
   713  	func() {
   714  		defer func() {
   715  			actual = recover()
   716  			didPanic = actual != nil
   717  		}()
   718  		action()
   719  	}()
   720  
   721  	if !didPanic || (didPanic && !areEqual(expected, actual)) {
   722  		return true, panicEqualMessage(didPanic, expected, actual)
   723  	}
   724  	return false, ""
   725  }
   726  
   727  func shouldNotBeEqual(expected, actual interface{}) (bool, string) {
   728  	if areEqual(expected, actual) {
   729  		return true, notEqualMessage(expected, actual)
   730  	}
   731  	return false, ""
   732  }
   733  
   734  func shouldNotBeNil(object interface{}) (bool, string) {
   735  	if isNil(object) {
   736  		return true, "Should not be nil"
   737  	}
   738  	return false, ""
   739  }
   740  
   741  func shouldBeNil(object interface{}) (bool, string) {
   742  	if !isNil(object) {
   743  		return true, shouldBeMessage(object, "Should be nil")
   744  	}
   745  	return false, ""
   746  }
   747  
   748  func shouldBeTrue(value bool) (bool, string) {
   749  	if !value {
   750  		return true, "Should be true"
   751  	}
   752  	return false, ""
   753  }
   754  
   755  func shouldBeFalse(value bool) (bool, string) {
   756  	if value {
   757  		return true, "Should be false"
   758  	}
   759  	return false, ""
   760  }
   761  
   762  func shouldBeZero(value interface{}) (bool, string) {
   763  	if !isZero(value) {
   764  		return true, shouldBeMessage(value, "Should be zero")
   765  	}
   766  	return false, ""
   767  }
   768  
   769  func shouldBeNonZero(value interface{}) (bool, string) {
   770  	if isZero(value) {
   771  		return true, "Should be non-zero"
   772  	}
   773  	return false, ""
   774  }
   775  
   776  func shouldFileExist(filePath string) (bool, string) {
   777  	_, err := os.Stat(filePath)
   778  	if err != nil {
   779  		pwd, _ := os.Getwd()
   780  		message := fmt.Sprintf("File doesnt exist: %s, `pwd`: %s", filePath, pwd)
   781  		return true, message
   782  	}
   783  	return false, ""
   784  }
   785  
   786  func shouldBeInDelta(from, to, delta float64) (bool, string) {
   787  	diff := math.Abs(from - to)
   788  	if diff > delta {
   789  		message := fmt.Sprintf("Absolute difference of %0.5f and %0.5f should be less than %0.5f", from, to, delta)
   790  		return true, message
   791  	}
   792  	return false, ""
   793  }
   794  
   795  func shouldBeInTimeDelta(from, to time.Time, delta time.Duration) (bool, string) {
   796  	var diff time.Duration
   797  	if from.After(to) {
   798  		diff = from.Sub(to)
   799  	} else {
   800  		diff = to.Sub(from)
   801  	}
   802  	if diff > delta {
   803  		message := fmt.Sprintf("Delta of %s and %s should be less than %v", from.Format(time.RFC3339), to.Format(time.RFC3339), delta)
   804  		return true, message
   805  	}
   806  	return false, ""
   807  }
   808  
   809  func shouldNotBeInTimeDelta(from, to time.Time, delta time.Duration) (bool, string) {
   810  	var diff time.Duration
   811  	if from.After(to) {
   812  		diff = from.Sub(to)
   813  	} else {
   814  		diff = to.Sub(from)
   815  	}
   816  
   817  	if diff <= delta {
   818  		message := fmt.Sprintf("Delta of %s and %s should be greater than %v", from.Format(time.RFC3339), to.Format(time.RFC3339), delta)
   819  		return true, message
   820  	}
   821  	return false, ""
   822  }
   823  
   824  func shouldMatch(pattern string, value interface{}) (bool, string) {
   825  	matched, err := regexp.MatchString(pattern, fmt.Sprint(value))
   826  	if err != nil {
   827  		panic(err)
   828  	}
   829  	if !matched {
   830  		message := fmt.Sprintf("`%v` should match `%s`", value, pattern)
   831  		return true, message
   832  	}
   833  	return false, ""
   834  }
   835  
   836  func shouldNotMatch(pattern string, value interface{}) (bool, string) {
   837  	matched, err := regexp.MatchString(pattern, fmt.Sprint(value))
   838  	if err != nil {
   839  		panic(err)
   840  	}
   841  	if matched {
   842  		message := fmt.Sprintf("`%v` should not match `%s`", value, pattern)
   843  		return true, message
   844  	}
   845  	return false, ""
   846  }
   847  
   848  func shouldContain(corpus, subString string) (bool, string) {
   849  	if !strings.Contains(corpus, subString) {
   850  		message := fmt.Sprintf("`%s` should contain `%s`", corpus, subString)
   851  		return true, message
   852  	}
   853  	return false, ""
   854  }
   855  
   856  func shouldNotContain(corpus, subString string) (bool, string) {
   857  	if strings.Contains(corpus, subString) {
   858  		message := fmt.Sprintf("`%s` should not contain `%s`", corpus, subString)
   859  		return true, message
   860  	}
   861  	return false, ""
   862  }
   863  
   864  func shouldHasPrefix(corpus, prefix string) (bool, string) {
   865  	if !strings.HasPrefix(corpus, prefix) {
   866  		message := fmt.Sprintf("`%s` should have prefix `%s`", corpus, prefix)
   867  		return true, message
   868  	}
   869  	return false, ""
   870  }
   871  
   872  func shouldNotHasPrefix(corpus, prefix string) (bool, string) {
   873  	if strings.HasPrefix(corpus, prefix) {
   874  		message := fmt.Sprintf("`%s` should not have prefix `%s`", corpus, prefix)
   875  		return true, message
   876  	}
   877  	return false, ""
   878  }
   879  
   880  func shouldHasSuffix(corpus, suffix string) (bool, string) {
   881  	if !strings.HasSuffix(corpus, suffix) {
   882  		message := fmt.Sprintf("`%s` should have suffix `%s`", corpus, suffix)
   883  		return true, message
   884  	}
   885  	return false, ""
   886  }
   887  
   888  func shouldNotHasSuffix(corpus, suffix string) (bool, string) {
   889  	if strings.HasSuffix(corpus, suffix) {
   890  		message := fmt.Sprintf("`%s` should not have suffix `%s`", corpus, suffix)
   891  		return true, message
   892  	}
   893  	return false, ""
   894  }
   895  
   896  func shouldAny(target interface{}, predicate Predicate) (bool, string) {
   897  	t := reflect.TypeOf(target)
   898  	for t.Kind() == reflect.Ptr {
   899  		t = t.Elem()
   900  	}
   901  
   902  	v := reflect.ValueOf(target)
   903  	for v.Kind() == reflect.Ptr {
   904  		v = v.Elem()
   905  	}
   906  
   907  	if t.Kind() != reflect.Slice {
   908  		return true, "`target` is not a slice"
   909  	}
   910  
   911  	for x := 0; x < v.Len(); x++ {
   912  		obj := v.Index(x).Interface()
   913  		if predicate(obj) {
   914  			return false, ""
   915  		}
   916  	}
   917  	return true, "Predicate did not fire for any element in target"
   918  }
   919  
   920  func shouldAnyCount(target interface{}, times int, predicate Predicate) (bool, string) {
   921  	t := reflect.TypeOf(target)
   922  	for t.Kind() == reflect.Ptr {
   923  		t = t.Elem()
   924  	}
   925  
   926  	v := reflect.ValueOf(target)
   927  	for v.Kind() == reflect.Ptr {
   928  		v = v.Elem()
   929  	}
   930  
   931  	if t.Kind() != reflect.Slice {
   932  		return true, "`target` is not a slice"
   933  	}
   934  
   935  	var seen int
   936  	for x := 0; x < v.Len(); x++ {
   937  		obj := v.Index(x).Interface()
   938  		if predicate(obj) {
   939  			seen++
   940  		}
   941  	}
   942  	if seen != times {
   943  		return true, shouldBeMultipleMessage(times, seen, "Predicate should fire a given number of times")
   944  	}
   945  	return false, ""
   946  }
   947  
   948  func shouldAnyOfInt(target []int, predicate PredicateOfInt) (bool, string) {
   949  	v := reflect.ValueOf(target)
   950  
   951  	for x := 0; x < v.Len(); x++ {
   952  		obj := v.Index(x).Interface().(int)
   953  		if predicate(obj) {
   954  			return false, ""
   955  		}
   956  	}
   957  	return true, "Predicate did not fire for any element in target"
   958  }
   959  
   960  func shouldAnyOfFloat(target []float64, predicate PredicateOfFloat) (bool, string) {
   961  	v := reflect.ValueOf(target)
   962  
   963  	for x := 0; x < v.Len(); x++ {
   964  		obj := v.Index(x).Interface().(float64)
   965  		if predicate(obj) {
   966  			return false, ""
   967  		}
   968  	}
   969  	return true, "Predicate did not fire for any element in target"
   970  }
   971  
   972  func shouldAnyOfString(target []string, predicate PredicateOfString) (bool, string) {
   973  	v := reflect.ValueOf(target)
   974  
   975  	for x := 0; x < v.Len(); x++ {
   976  		obj := v.Index(x).Interface().(string)
   977  		if predicate(obj) {
   978  			return false, ""
   979  		}
   980  	}
   981  	return true, "Predicate did not fire for any element in target"
   982  }
   983  
   984  func shouldAll(target interface{}, predicate Predicate) (bool, string) {
   985  	t := reflect.TypeOf(target)
   986  	for t.Kind() == reflect.Ptr {
   987  		t = t.Elem()
   988  	}
   989  
   990  	v := reflect.ValueOf(target)
   991  	for v.Kind() == reflect.Ptr {
   992  		v = v.Elem()
   993  	}
   994  
   995  	if t.Kind() != reflect.Slice {
   996  		return true, "`target` is not a slice"
   997  	}
   998  
   999  	for x := 0; x < v.Len(); x++ {
  1000  		obj := v.Index(x).Interface()
  1001  		if !predicate(obj) {
  1002  			return true, fmt.Sprintf("Predicate failed for element in target: %#v", obj)
  1003  		}
  1004  	}
  1005  	return false, ""
  1006  }
  1007  
  1008  func shouldAllOfInt(target []int, predicate PredicateOfInt) (bool, string) {
  1009  	v := reflect.ValueOf(target)
  1010  
  1011  	for x := 0; x < v.Len(); x++ {
  1012  		obj := v.Index(x).Interface().(int)
  1013  		if !predicate(obj) {
  1014  			return true, fmt.Sprintf("Predicate failed for element in target: %#v", obj)
  1015  		}
  1016  	}
  1017  	return false, ""
  1018  }
  1019  
  1020  func shouldAllOfFloat(target []float64, predicate PredicateOfFloat) (bool, string) {
  1021  	v := reflect.ValueOf(target)
  1022  
  1023  	for x := 0; x < v.Len(); x++ {
  1024  		obj := v.Index(x).Interface().(float64)
  1025  		if !predicate(obj) {
  1026  			return true, fmt.Sprintf("Predicate failed for element in target: %#v", obj)
  1027  		}
  1028  	}
  1029  	return false, ""
  1030  }
  1031  
  1032  func shouldAllOfString(target []string, predicate PredicateOfString) (bool, string) {
  1033  	v := reflect.ValueOf(target)
  1034  
  1035  	for x := 0; x < v.Len(); x++ {
  1036  		obj := v.Index(x).Interface().(string)
  1037  		if !predicate(obj) {
  1038  			return true, fmt.Sprintf("Predicate failed for element in target: %#v", obj)
  1039  		}
  1040  	}
  1041  	return false, ""
  1042  }
  1043  
  1044  func shouldNone(target interface{}, predicate Predicate) (bool, string) {
  1045  	t := reflect.TypeOf(target)
  1046  	for t.Kind() == reflect.Ptr {
  1047  		t = t.Elem()
  1048  	}
  1049  
  1050  	v := reflect.ValueOf(target)
  1051  	for v.Kind() == reflect.Ptr {
  1052  		v = v.Elem()
  1053  	}
  1054  
  1055  	if t.Kind() != reflect.Slice {
  1056  		return true, "`target` is not a slice"
  1057  	}
  1058  
  1059  	for x := 0; x < v.Len(); x++ {
  1060  		obj := v.Index(x).Interface()
  1061  		if predicate(obj) {
  1062  			return true, fmt.Sprintf("Predicate passed for element in target: %#v", obj)
  1063  		}
  1064  	}
  1065  	return false, ""
  1066  }
  1067  
  1068  func shouldNoneOfInt(target []int, predicate PredicateOfInt) (bool, string) {
  1069  	v := reflect.ValueOf(target)
  1070  
  1071  	for x := 0; x < v.Len(); x++ {
  1072  		obj := v.Index(x).Interface().(int)
  1073  		if predicate(obj) {
  1074  			return true, fmt.Sprintf("Predicate passed for element in target: %#v", obj)
  1075  		}
  1076  	}
  1077  	return false, ""
  1078  }
  1079  
  1080  func shouldNoneOfFloat(target []float64, predicate PredicateOfFloat) (bool, string) {
  1081  	v := reflect.ValueOf(target)
  1082  
  1083  	for x := 0; x < v.Len(); x++ {
  1084  		obj := v.Index(x).Interface().(float64)
  1085  		if predicate(obj) {
  1086  			return true, fmt.Sprintf("Predicate passed for element in target: %#v", obj)
  1087  		}
  1088  	}
  1089  	return false, ""
  1090  }
  1091  
  1092  func shouldNoneOfString(target []string, predicate PredicateOfString) (bool, string) {
  1093  	v := reflect.ValueOf(target)
  1094  
  1095  	for x := 0; x < v.Len(); x++ {
  1096  		obj := v.Index(x).Interface().(string)
  1097  		if predicate(obj) {
  1098  			return true, fmt.Sprintf("Predicate passed for element in target: %#v", obj)
  1099  		}
  1100  	}
  1101  	return false, ""
  1102  }
  1103  
  1104  // --------------------------------------------------------------------------------
  1105  // UTILITY
  1106  // --------------------------------------------------------------------------------
  1107  
  1108  func shouldBeMultipleMessage(expected, actual interface{}, message string) string {
  1109  	expectedLabel := color("Expected", WHITE)
  1110  	actualLabel := color("Actual", WHITE)
  1111  
  1112  	return fmt.Sprintf(`%s
  1113  	%s: 	%#v
  1114  	%s: 	%#v`, message, expectedLabel, expected, actualLabel, actual)
  1115  }
  1116  
  1117  func shouldBeMessage(object interface{}, message string) string {
  1118  	actualLabel := color("Actual", WHITE)
  1119  	if err, ok := object.(error); ok {
  1120  		return fmt.Sprintf(`%s
  1121  	%s: 	%+v`, message, actualLabel, err)
  1122  	}
  1123  	return fmt.Sprintf(`%s
  1124  	%s: 	%#v`, message, actualLabel, object)
  1125  }
  1126  
  1127  func notEqualMessage(expected, actual interface{}) string {
  1128  	return shouldBeMultipleMessage(expected, actual, "Objects should not be equal")
  1129  }
  1130  
  1131  func equalMessage(expected, actual interface{}) string {
  1132  	return shouldBeMultipleMessage(expected, actual, "Objects should be equal")
  1133  }
  1134  
  1135  func referenceEqualMessage(expected, actual interface{}) string {
  1136  	return shouldBeMultipleMessage(expected, actual, "References should be equal")
  1137  }
  1138  
  1139  func panicEqualMessage(didPanic bool, expected, actual interface{}) string {
  1140  	if !didPanic {
  1141  		return "Should have produced a panic"
  1142  	}
  1143  	return shouldBeMultipleMessage(expected, actual, "Panic from action should equal")
  1144  }
  1145  
  1146  func notPanicMessage(actual interface{}) string {
  1147  	return shouldBeMessage(actual, "Should not have panicked")
  1148  }
  1149  
  1150  func getLength(object interface{}) int {
  1151  	if object == nil {
  1152  		return 0
  1153  	} else if object == "" {
  1154  		return 0
  1155  	}
  1156  
  1157  	objValue := reflect.ValueOf(object)
  1158  
  1159  	switch objValue.Kind() {
  1160  	case reflect.Map:
  1161  		fallthrough
  1162  	case reflect.Slice, reflect.Chan, reflect.String:
  1163  		{
  1164  			return objValue.Len()
  1165  		}
  1166  	}
  1167  	return 0
  1168  }
  1169  
  1170  func isNil(object interface{}) bool {
  1171  	if object == nil {
  1172  		return true
  1173  	}
  1174  
  1175  	value := reflect.ValueOf(object)
  1176  	kind := value.Kind()
  1177  	if kind >= reflect.Chan && kind <= reflect.Slice && value.IsNil() {
  1178  		return true
  1179  	}
  1180  	return false
  1181  }
  1182  
  1183  func isZero(value interface{}) bool {
  1184  	return areEqual(0, value)
  1185  }
  1186  
  1187  func areReferenceEqual(expected, actual interface{}) bool {
  1188  	if expected == nil && actual == nil {
  1189  		return true
  1190  	}
  1191  	if (expected == nil && actual != nil) || (expected != nil && actual == nil) {
  1192  		return false
  1193  	}
  1194  
  1195  	return expected == actual
  1196  }
  1197  
  1198  func areEqual(expected, actual interface{}) bool {
  1199  	if expected == nil && actual == nil {
  1200  		return true
  1201  	}
  1202  	if (expected == nil && actual != nil) || (expected != nil && actual == nil) {
  1203  		return false
  1204  	}
  1205  
  1206  	actualType := reflect.TypeOf(actual)
  1207  	if actualType == nil {
  1208  		return false
  1209  	}
  1210  	expectedValue := reflect.ValueOf(expected)
  1211  	if expectedValue.IsValid() && expectedValue.Type().ConvertibleTo(actualType) {
  1212  		return reflect.DeepEqual(expectedValue.Convert(actualType).Interface(), actual)
  1213  	}
  1214  
  1215  	return reflect.DeepEqual(expected, actual)
  1216  }