gopkg.in/rethinkdb/rethinkdb-go.v6@v6.2.2/internal/compare/compare.go (about) 1 package compare 2 3 import ( 4 "fmt" 5 "math" 6 "reflect" 7 "regexp" 8 "strconv" 9 "testing" 10 11 "github.com/stretchr/testify/assert" 12 ) 13 14 var AnythingIsFine = "reql_test.AnythingIsFine" 15 16 func Assert(t *testing.T, expected, actual interface{}) { 17 expectedVal := expected 18 if e, ok := expected.(Expected); ok { 19 expectedVal = e.Val 20 } 21 22 ok, msg := Compare(expected, actual) 23 if !ok { 24 assert.Fail(t, fmt.Sprintf("Not equal: %#v (expected)\n != %#v (actual)", expectedVal, actual), msg) 25 } 26 } 27 28 func AssertFalse(t *testing.T, expected, actual interface{}) { 29 expectedVal := expected 30 if e, ok := expected.(Expected); ok { 31 expectedVal = e.Val 32 } 33 34 ok, msg := Compare(expected, actual) 35 if ok { 36 assert.Fail(t, fmt.Sprintf("Should not be equal: %#v (expected)\n == %#v (actual)", expectedVal, actual), msg) 37 } 38 } 39 40 func AssertPrecision(t *testing.T, expected, actual interface{}, precision float64) { 41 expectedVal := expected 42 if e, ok := expected.(Expected); ok { 43 expectedVal = e.Val 44 } 45 46 ok, msg := ComparePrecision(expected, actual, precision) 47 if !ok { 48 assert.Fail(t, fmt.Sprintf("Not equal: %#v (expected)\n != %#v (actual)", expectedVal, actual), msg) 49 } 50 } 51 52 func AssertPrecisionFalse(t *testing.T, expected, actual interface{}, precision float64) { 53 expectedVal := expected 54 if e, ok := expected.(Expected); ok { 55 expectedVal = e.Val 56 } 57 58 ok, msg := ComparePrecision(expected, actual, precision) 59 if ok { 60 assert.Fail(t, fmt.Sprintf("Should not be equal: %#v (expected)\n == %#v (actual)", expectedVal, actual), msg) 61 } 62 } 63 64 func Compare(expected, actual interface{}) (bool, string) { 65 return ComparePrecision(expected, actual, 0.00000000001) 66 } 67 68 func ComparePrecision(expected, actual interface{}, precision float64) (bool, string) { 69 return compare(expected, actual, true, false, precision) 70 } 71 72 func compare(expected, actual interface{}, ordered, partial bool, precision float64) (bool, string) { 73 if e, ok := expected.(Expected); ok { 74 partial = e.Partial 75 ordered = e.Ordered 76 expected = e.Val 77 } 78 79 // Anything 80 if expected == AnythingIsFine { 81 return true, "" 82 } 83 84 expectedVal := reflect.ValueOf(expected) 85 actualVal := reflect.ValueOf(actual) 86 87 // Nil 88 if expected == nil { 89 switch actualVal.Kind() { 90 case reflect.Bool: 91 expected = false 92 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 93 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, 94 reflect.Float32, reflect.Float64: 95 expected = 0.0 96 case reflect.String: 97 expected = "" 98 } 99 100 if expected == actual { 101 return true, "" 102 } 103 } 104 105 // Regex 106 if expr, ok := expected.(Regex); ok { 107 re, err := regexp.Compile(string(expr)) 108 if err != nil { 109 return false, fmt.Sprintf("Failed to compile regexp: %s", err) 110 } 111 112 if actualVal.Kind() != reflect.String { 113 return false, fmt.Sprintf("Expected string, got %t (%T)", actual, actual) 114 } 115 116 if !re.MatchString(actualVal.String()) { 117 return false, fmt.Sprintf("Value %v did not match regexp '%s'", actual, expr) 118 } 119 120 return true, "" 121 } 122 123 switch expectedVal.Kind() { 124 125 // Bool 126 case reflect.Bool: 127 if expected == actual { 128 return true, "" 129 } 130 // Number 131 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 132 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, 133 reflect.Float32, reflect.Float64: 134 switch actualVal.Kind() { 135 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, 136 reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, 137 reflect.Float32, reflect.Float64, reflect.String: 138 diff := math.Abs(reflectNumber(expectedVal) - reflectNumber(actualVal)) 139 if diff <= precision { 140 return true, "" 141 } 142 143 if precision != 0 { 144 return false, fmt.Sprintf("Value %v was not within %f of %v", expected, precision, actual) 145 } 146 147 return false, fmt.Sprintf("Expected %v but got %v", expected, actual) 148 } 149 150 // String 151 case reflect.String: 152 actualStr := fmt.Sprintf("%v", actual) 153 if expected == actualStr { 154 return true, "" 155 } 156 // Struct 157 case reflect.Struct: 158 // Convert expected struct to map and compare with actual value 159 return compare(reflectMap(expectedVal), actual, ordered, partial, precision) 160 // Map 161 case reflect.Map: 162 switch actualVal.Kind() { 163 case reflect.Struct: 164 // Convert actual struct to map and compare with expected map 165 return compare(expected, reflectMap(actualVal), ordered, partial, precision) 166 case reflect.Map: 167 expectedKeys := expectedVal.MapKeys() 168 actualKeys := actualVal.MapKeys() 169 170 for _, expectedKey := range expectedKeys { 171 keyFound := false 172 for _, actualKey := range actualKeys { 173 if ok, _ := Compare(expectedKey.Interface(), actualKey.Interface()); ok { 174 keyFound = true 175 break 176 } 177 } 178 if !keyFound { 179 return false, fmt.Sprintf("Expected field %v but not found", expectedKey) 180 } 181 } 182 183 if !partial { 184 expectedKeyVals := reflectMapKeys(expectedKeys) 185 actualKeyVals := reflectMapKeys(actualKeys) 186 if ok, _ := compare(expectedKeyVals, actualKeyVals, false, false, 0.0); !ok { 187 return false, fmt.Sprintf( 188 "Unmatched keys from either side: expected fields %v, got %v", 189 expectedKeyVals, actualKeyVals, 190 ) 191 } 192 } 193 194 expectedMap := reflectMap(expectedVal) 195 actualMap := reflectMap(actualVal) 196 197 for k, v := range expectedMap { 198 if ok, reason := compare(v, actualMap[k], ordered, partial, precision); !ok { 199 return false, reason 200 } 201 } 202 203 return true, "" 204 default: 205 return false, fmt.Sprintf("Expected map, got %v (%T)", actual, actual) 206 } 207 // Slice/Array 208 case reflect.Slice, reflect.Array: 209 switch actualVal.Kind() { 210 case reflect.Slice, reflect.Array: 211 if ordered { 212 expectedArr := reflectSlice(expectedVal) 213 actualArr := reflectSlice(actualVal) 214 215 j := 0 216 for i := 0; i < len(expectedArr); i++ { 217 expectedArrVal := expectedArr[i] 218 for { 219 if j >= len(actualArr) { 220 return false, fmt.Sprintf("Ran out of results before finding %v", expectedArrVal) 221 } 222 223 actualArrVal := actualArr[j] 224 j++ 225 226 if ok, _ := compare(expectedArrVal, actualArrVal, ordered, partial, precision); ok { 227 break 228 } else if !partial { 229 return false, fmt.Sprintf("Unexpected item %v while looking for %v", actualArrVal, expectedArrVal) 230 } 231 } 232 } 233 if !partial && j < len(actualArr) { 234 return false, fmt.Sprintf("Unexpected extra results: %v", actualArr[j:]) 235 } 236 } else { 237 expectedArr := reflectSlice(expectedVal) 238 actualArr := reflectSlice(actualVal) 239 240 for _, expectedArrVal := range expectedArr { 241 found := false 242 for j, actualArrVal := range actualArr { 243 if ok, _ := compare(expectedArrVal, actualArrVal, ordered, partial, precision); ok { 244 found = true 245 actualArr = append(actualArr[:j], actualArr[j+1:]...) 246 break 247 } 248 } 249 if !found { 250 return false, fmt.Sprintf("Missing expected item %v", expectedArrVal) 251 } 252 } 253 254 if !partial && len(actualArr) > 0 { 255 return false, fmt.Sprintf("Extra items returned: %v", expectedArr) 256 } 257 } 258 259 return true, "" 260 } 261 // Other 262 default: 263 if expected == actual { 264 return true, "" 265 } 266 } 267 268 return false, fmt.Sprintf("Expected %v (%T) but got %v (%T)", expected, expected, actual, actual) 269 } 270 271 func reflectNumber(v reflect.Value) float64 { 272 switch v.Kind() { 273 case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: 274 return float64(v.Int()) 275 case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64: 276 return float64(v.Uint()) 277 case reflect.Float32, reflect.Float64: 278 return v.Float() 279 case reflect.String: 280 f, _ := strconv.ParseFloat(v.String(), 64) 281 return f 282 default: 283 return float64(0) 284 } 285 } 286 287 func reflectMap(v reflect.Value) map[interface{}]interface{} { 288 switch v.Kind() { 289 case reflect.Struct: 290 m := map[interface{}]interface{}{} 291 for i := 0; i < v.NumField(); i++ { 292 sf := v.Type().Field(i) 293 if sf.PkgPath != "" && !sf.Anonymous { 294 continue // unexported 295 } 296 297 k := sf.Name 298 v := v.Field(i).Interface() 299 300 m[k] = v 301 } 302 return m 303 case reflect.Map: 304 m := map[interface{}]interface{}{} 305 for _, mk := range v.MapKeys() { 306 k := "" 307 if mk.Interface() != nil { 308 k = fmt.Sprintf("%v", mk.Interface()) 309 } 310 v := v.MapIndex(mk).Interface() 311 312 m[k] = v 313 } 314 return m 315 default: 316 return nil 317 } 318 } 319 320 func reflectSlice(v reflect.Value) []interface{} { 321 switch v.Kind() { 322 case reflect.Slice, reflect.Array: 323 s := []interface{}{} 324 for i := 0; i < v.Len(); i++ { 325 s = append(s, v.Index(i).Interface()) 326 } 327 return s 328 default: 329 return nil 330 } 331 } 332 333 func reflectMapKeys(keys []reflect.Value) []interface{} { 334 s := []interface{}{} 335 for _, key := range keys { 336 s = append(s, key.Interface()) 337 } 338 return s 339 } 340 341 func reflectInterfaces(vals []reflect.Value) []interface{} { 342 ret := []interface{}{} 343 for _, val := range vals { 344 ret = append(ret, val.Interface()) 345 } 346 return ret 347 }