github.com/prysmaticlabs/prysm@v1.4.4/shared/testutil/assertions/assertions.go (about)

     1  package assertions
     2  
     3  import (
     4  	"fmt"
     5  	"path/filepath"
     6  	"reflect"
     7  	"runtime"
     8  	"strings"
     9  
    10  	"github.com/d4l3k/messagediff"
    11  	"github.com/prysmaticlabs/prysm/shared/sszutil"
    12  	"github.com/sirupsen/logrus/hooks/test"
    13  	"google.golang.org/protobuf/proto"
    14  )
    15  
    16  // AssertionTestingTB exposes enough testing.TB methods for assertions.
    17  type AssertionTestingTB interface {
    18  	Errorf(format string, args ...interface{})
    19  	Fatalf(format string, args ...interface{})
    20  }
    21  
    22  type assertionLoggerFn func(string, ...interface{})
    23  
    24  // Equal compares values using comparison operator.
    25  func Equal(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
    26  	if expected != actual {
    27  		errMsg := parseMsg("Values are not equal", msg...)
    28  		_, file, line, _ := runtime.Caller(2)
    29  		loggerFn("%s:%d %s, want: %[4]v (%[4]T), got: %[5]v (%[5]T)", filepath.Base(file), line, errMsg, expected, actual)
    30  	}
    31  }
    32  
    33  // NotEqual compares values using comparison operator.
    34  func NotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
    35  	if expected == actual {
    36  		errMsg := parseMsg("Values are equal", msg...)
    37  		_, file, line, _ := runtime.Caller(2)
    38  		loggerFn("%s:%d %s, both values are equal: %[4]v (%[4]T)", filepath.Base(file), line, errMsg, expected)
    39  	}
    40  }
    41  
    42  // DeepEqual compares values using DeepEqual.
    43  func DeepEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
    44  	if !isDeepEqual(expected, actual) {
    45  		errMsg := parseMsg("Values are not equal", msg...)
    46  		_, file, line, _ := runtime.Caller(2)
    47  		diff, _ := messagediff.PrettyDiff(expected, actual)
    48  		loggerFn("%s:%d %s, want: %#v, got: %#v, diff: %s", filepath.Base(file), line, errMsg, expected, actual, diff)
    49  	}
    50  }
    51  
    52  // DeepNotEqual compares values using DeepEqual.
    53  func DeepNotEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
    54  	if isDeepEqual(expected, actual) {
    55  		errMsg := parseMsg("Values are equal", msg...)
    56  		_, file, line, _ := runtime.Caller(2)
    57  		loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)
    58  	}
    59  }
    60  
    61  // DeepSSZEqual compares values using sszutil.DeepEqual.
    62  func DeepSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
    63  	if !sszutil.DeepEqual(expected, actual) {
    64  		errMsg := parseMsg("Values are not equal", msg...)
    65  		_, file, line, _ := runtime.Caller(2)
    66  		diff, _ := messagediff.PrettyDiff(expected, actual)
    67  		loggerFn("%s:%d %s, want: %#v, got: %#v, diff: %s", filepath.Base(file), line, errMsg, expected, actual, diff)
    68  	}
    69  }
    70  
    71  // DeepNotSSZEqual compares values using sszutil.DeepEqual.
    72  func DeepNotSSZEqual(loggerFn assertionLoggerFn, expected, actual interface{}, msg ...interface{}) {
    73  	if sszutil.DeepEqual(expected, actual) {
    74  		errMsg := parseMsg("Values are equal", msg...)
    75  		_, file, line, _ := runtime.Caller(2)
    76  		loggerFn("%s:%d %s, want: %#v, got: %#v", filepath.Base(file), line, errMsg, expected, actual)
    77  	}
    78  }
    79  
    80  // NoError asserts that error is nil.
    81  func NoError(loggerFn assertionLoggerFn, err error, msg ...interface{}) {
    82  	if err != nil {
    83  		errMsg := parseMsg("Unexpected error", msg...)
    84  		_, file, line, _ := runtime.Caller(2)
    85  		loggerFn("%s:%d %s: %v", filepath.Base(file), line, errMsg, err)
    86  	}
    87  }
    88  
    89  // ErrorContains asserts that actual error contains wanted message.
    90  func ErrorContains(loggerFn assertionLoggerFn, want string, err error, msg ...interface{}) {
    91  	if err == nil || !strings.Contains(err.Error(), want) {
    92  		errMsg := parseMsg("Expected error not returned", msg...)
    93  		_, file, line, _ := runtime.Caller(2)
    94  		loggerFn("%s:%d %s, got: %v, want: %s", filepath.Base(file), line, errMsg, err, want)
    95  	}
    96  }
    97  
    98  // NotNil asserts that passed value is not nil.
    99  func NotNil(loggerFn assertionLoggerFn, obj interface{}, msg ...interface{}) {
   100  	if isNil(obj) {
   101  		errMsg := parseMsg("Unexpected nil value", msg...)
   102  		_, file, line, _ := runtime.Caller(2)
   103  		loggerFn("%s:%d %s", filepath.Base(file), line, errMsg)
   104  	}
   105  }
   106  
   107  // isNil checks that underlying value of obj is nil.
   108  func isNil(obj interface{}) bool {
   109  	if obj == nil {
   110  		return true
   111  	}
   112  	value := reflect.ValueOf(obj)
   113  	switch value.Kind() {
   114  	case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice, reflect.UnsafePointer:
   115  		return value.IsNil()
   116  	}
   117  	return false
   118  }
   119  
   120  // LogsContain checks whether a given substring is a part of logs. If flag=false, inverse is checked.
   121  func LogsContain(loggerFn assertionLoggerFn, hook *test.Hook, want string, flag bool, msg ...interface{}) {
   122  	_, file, line, _ := runtime.Caller(2)
   123  	entries := hook.AllEntries()
   124  	var logs []string
   125  	match := false
   126  	for _, e := range entries {
   127  		msg, err := e.String()
   128  		if err != nil {
   129  			loggerFn("%s:%d Failed to format log entry to string: %v", filepath.Base(file), line, err)
   130  			return
   131  		}
   132  		if strings.Contains(msg, want) {
   133  			match = true
   134  		}
   135  		for _, field := range e.Data {
   136  			fieldStr, ok := field.(string)
   137  			if !ok {
   138  				continue
   139  			}
   140  			if strings.Contains(fieldStr, want) {
   141  				match = true
   142  			}
   143  		}
   144  		logs = append(logs, msg)
   145  	}
   146  	var errMsg string
   147  	if flag && !match {
   148  		errMsg = parseMsg("Expected log not found", msg...)
   149  	} else if !flag && match {
   150  		errMsg = parseMsg("Unexpected log found", msg...)
   151  	}
   152  	if errMsg != "" {
   153  		loggerFn("%s:%d %s: %v\nSearched logs:\n%v", filepath.Base(file), line, errMsg, want, logs)
   154  	}
   155  }
   156  
   157  func parseMsg(defaultMsg string, msg ...interface{}) string {
   158  	if len(msg) >= 1 {
   159  		msgFormat, ok := msg[0].(string)
   160  		if !ok {
   161  			return defaultMsg
   162  		}
   163  		return fmt.Sprintf(msgFormat, msg[1:]...)
   164  	}
   165  	return defaultMsg
   166  }
   167  
   168  func isDeepEqual(expected, actual interface{}) bool {
   169  	_, isProto := expected.(proto.Message)
   170  	if isProto {
   171  		return proto.Equal(expected.(proto.Message), actual.(proto.Message))
   172  	}
   173  	return reflect.DeepEqual(expected, actual)
   174  }
   175  
   176  // TBMock exposes enough testing.TB methods for assertions.
   177  type TBMock struct {
   178  	ErrorfMsg string
   179  	FatalfMsg string
   180  }
   181  
   182  // Errorf writes testing logs to ErrorfMsg.
   183  func (tb *TBMock) Errorf(format string, args ...interface{}) {
   184  	tb.ErrorfMsg = fmt.Sprintf(format, args...)
   185  }
   186  
   187  // Fatalf writes testing logs to FatalfMsg.
   188  func (tb *TBMock) Fatalf(format string, args ...interface{}) {
   189  	tb.FatalfMsg = fmt.Sprintf(format, args...)
   190  }