github.com/hoveychen/protoreflect@v1.4.7-0.20221103114119-0b4b3385ec76/internal/testutil/asserts.go (about) 1 package testutil 2 3 import ( 4 "bytes" 5 "fmt" 6 "math" 7 "os" 8 "reflect" 9 "runtime" 10 "strings" 11 "testing" 12 ) 13 14 // Ceq is a custom equals check; the given function returns true if its arguments are equal 15 func Ceq(t *testing.T, expected, actual interface{}, eq func(a, b interface{}) bool, context ...interface{}) bool { 16 return ceq(getCaller(), t, expected, actual, eq, context) 17 } 18 19 func ceq(caller string, t *testing.T, expected, actual interface{}, eq func(a, b interface{}) bool, context []interface{}) bool { 20 e := eq(expected, actual) 21 require(caller, t, e, mergeContext(context, "Expecting %v (%v), got %v (%v)", expected, reflect.TypeOf(expected), actual, reflect.TypeOf(actual))) 22 return e 23 } 24 25 // Cneq is a custom not-equals check; the given function returns true if its arguments are equal 26 func Cneq(t *testing.T, unexpected, actual interface{}, eq func(a, b interface{}) bool, context ...interface{}) bool { 27 return cneq(getCaller(), t, unexpected, actual, eq, context) 28 } 29 30 func cneq(caller string, t *testing.T, unexpected, actual interface{}, eq func(a, b interface{}) bool, context []interface{}) bool { 31 ne := !eq(unexpected, actual) 32 require(caller, t, ne, mergeContext(context, "Value should not be %v (%v)", unexpected, reflect.TypeOf(unexpected))) 33 return ne 34 } 35 36 // Require is an assertion that logs a failure if its given argument is not true 37 func Require(t *testing.T, condition bool, context ...interface{}) { 38 require(getCaller(), t, condition, context) 39 } 40 41 func require(caller string, t *testing.T, condition bool, context []interface{}) { 42 if !condition { 43 if len(context) == 0 { 44 t.Fatalf("%s: Assertion failed", caller) 45 } else { 46 msg := context[0].(string) 47 // if any args were deferred (e.g. a function instead of a value), get those args now 48 args := make([]interface{}, len(context)-1) 49 for i, a := range context[1:] { 50 rv := reflect.ValueOf(a) 51 if rv.Kind() == reflect.Func { 52 a = rv.Call([]reflect.Value{})[0].Interface() 53 } 54 args[i] = a 55 } 56 t.Fatalf("%s: %s", caller, fmt.Sprintf(msg, args...)) 57 } 58 } 59 } 60 61 func mergeContext(context []interface{}, msg string, msgArgs ...interface{}) []interface{} { 62 if len(context) == 0 { 63 ret := make([]interface{}, len(msgArgs)+1) 64 ret[0] = msg 65 for i, a := range msgArgs { 66 ret[i+1] = a 67 } 68 return ret 69 } else { 70 ret := make([]interface{}, len(msgArgs)+2) 71 ret[0] = msg + ": %s" 72 for i, a := range msgArgs { 73 ret[i+1] = a 74 } 75 ret[len(ret)-1] = func() string { 76 f := context[0].(string) 77 return fmt.Sprintf(f, context[1:]...) 78 } 79 return ret 80 } 81 } 82 83 func getCaller() string { 84 pc, file, line, ok := runtime.Caller(2) 85 if !ok { 86 return "?" 87 } 88 fn := runtime.FuncForPC(pc) 89 var fnName string 90 if fn == nil { 91 fnName = "?" 92 } else { 93 fnName = fn.Name() 94 } 95 return fmt.Sprintf("%s(%s:%d)", lastComponents(fnName, 1), lastComponents(file, 2), line) 96 } 97 98 const pathSep = string(os.PathSeparator) 99 100 func lastComponents(s string, count int) string { 101 ss := s 102 var i int 103 for count > 0 { 104 i = strings.LastIndex(ss, pathSep) 105 if i < 0 { 106 return s 107 } 108 count-- 109 ss = ss[:i] 110 } 111 return s[i+1:] 112 } 113 114 // Ok asserts that the given error is nil 115 func Ok(t *testing.T, err error, context ...interface{}) { 116 require(getCaller(), t, err == nil, mergeContext(context, "Unexpected error: %s", func() interface{} { return err.Error() })) 117 } 118 119 // Nok asserts that the given error is not nil 120 func Nok(t *testing.T, err error, context ...interface{}) { 121 require(getCaller(), t, err != nil, mergeContext(context, "Expected error but got none")) 122 } 123 124 // Eq asserts that the given two values are equal 125 func Eq(t *testing.T, expected, actual interface{}, context ...interface{}) bool { 126 return ceq(getCaller(), t, expected, actual, eqany, context) 127 } 128 129 // Neq asserts that the given two values are not equal 130 func Neq(t *testing.T, unexpected, actual interface{}, context ...interface{}) bool { 131 return cneq(getCaller(), t, unexpected, actual, eqany, context) 132 } 133 134 // default equality test and helpers 135 136 func eqany(expected, actual interface{}) bool { 137 if expected == nil && actual == nil { 138 return true 139 } 140 if expected == nil || actual == nil { 141 return false 142 } 143 144 // We don't want reflect.DeepEquals because of its recursive nature. So we need 145 // a custom compare for slices and maps. Two slices are equal if they have the 146 // same number of elements and the elements at the same index are equal to each 147 // other. Two maps are equal if their key sets are the same and the corresponding 148 // values are equal. (The relationship is not recursive, slices or maps that 149 // contain other slices or maps can't be tested.) 150 et := reflect.TypeOf(expected) 151 152 if et.Kind() == reflect.Slice { 153 return eqslice(reflect.ValueOf(expected), reflect.ValueOf(actual)) 154 } else if et.Kind() == reflect.Map { 155 return eqmap(reflect.ValueOf(expected), reflect.ValueOf(actual)) 156 } else { 157 return eqscalar(expected, actual) 158 } 159 } 160 161 func eqscalar(expected, actual interface{}) bool { 162 // special-case simple equality for []byte (since slices aren't directly comparable) 163 if e, ok := expected.([]byte); ok { 164 a, ok := actual.([]byte) 165 return ok && bytes.Equal(e, a) 166 } 167 // and special-cases to handle NaN 168 if e, ok := expected.(float32); ok && math.IsNaN(float64(e)) { 169 a, ok := actual.(float32) 170 return ok && math.IsNaN(float64(a)) 171 } 172 if e, ok := expected.(float64); ok && math.IsNaN(e) { 173 a, ok := actual.(float64) 174 return ok && math.IsNaN(a) 175 } 176 // simple logic for everything else 177 return expected == actual 178 } 179 180 func eqslice(expected, actual reflect.Value) bool { 181 if expected.Len() != actual.Len() { 182 return false 183 } 184 for i := 0; i < expected.Len(); i++ { 185 e := expected.Index(i).Interface() 186 a := actual.Index(i).Interface() 187 if !eqscalar(e, a) { 188 return false 189 } 190 } 191 return true 192 } 193 194 func eqmap(expected, actual reflect.Value) bool { 195 if expected.Len() != actual.Len() { 196 return false 197 } 198 for _, k := range expected.MapKeys() { 199 e := expected.MapIndex(k) 200 a := actual.MapIndex(k) 201 if !a.IsValid() { 202 return false 203 } 204 if !eqscalar(e.Interface(), a.Interface()) { 205 return false 206 } 207 } 208 return true 209 }