gopkg.in/rethinkdb/rethinkdb-go.v6@v6.2.2/internal/compare/compare.go (about)

     1  package compare
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"reflect"
     7  	"regexp"
     8  	"strconv"
     9  	"testing"
    10  
    11  	"github.com/stretchr/testify/assert"
    12  )
    13  
    14  var AnythingIsFine = "reql_test.AnythingIsFine"
    15  
    16  func Assert(t *testing.T, expected, actual interface{}) {
    17  	expectedVal := expected
    18  	if e, ok := expected.(Expected); ok {
    19  		expectedVal = e.Val
    20  	}
    21  
    22  	ok, msg := Compare(expected, actual)
    23  	if !ok {
    24  		assert.Fail(t, fmt.Sprintf("Not equal: %#v (expected)\n           != %#v (actual)", expectedVal, actual), msg)
    25  	}
    26  }
    27  
    28  func AssertFalse(t *testing.T, expected, actual interface{}) {
    29  	expectedVal := expected
    30  	if e, ok := expected.(Expected); ok {
    31  		expectedVal = e.Val
    32  	}
    33  
    34  	ok, msg := Compare(expected, actual)
    35  	if ok {
    36  		assert.Fail(t, fmt.Sprintf("Should not be equal: %#v (expected)\n           == %#v (actual)", expectedVal, actual), msg)
    37  	}
    38  }
    39  
    40  func AssertPrecision(t *testing.T, expected, actual interface{}, precision float64) {
    41  	expectedVal := expected
    42  	if e, ok := expected.(Expected); ok {
    43  		expectedVal = e.Val
    44  	}
    45  
    46  	ok, msg := ComparePrecision(expected, actual, precision)
    47  	if !ok {
    48  		assert.Fail(t, fmt.Sprintf("Not equal: %#v (expected)\n           != %#v (actual)", expectedVal, actual), msg)
    49  	}
    50  }
    51  
    52  func AssertPrecisionFalse(t *testing.T, expected, actual interface{}, precision float64) {
    53  	expectedVal := expected
    54  	if e, ok := expected.(Expected); ok {
    55  		expectedVal = e.Val
    56  	}
    57  
    58  	ok, msg := ComparePrecision(expected, actual, precision)
    59  	if ok {
    60  		assert.Fail(t, fmt.Sprintf("Should not be equal: %#v (expected)\n           == %#v (actual)", expectedVal, actual), msg)
    61  	}
    62  }
    63  
    64  func Compare(expected, actual interface{}) (bool, string) {
    65  	return ComparePrecision(expected, actual, 0.00000000001)
    66  }
    67  
    68  func ComparePrecision(expected, actual interface{}, precision float64) (bool, string) {
    69  	return compare(expected, actual, true, false, precision)
    70  }
    71  
    72  func compare(expected, actual interface{}, ordered, partial bool, precision float64) (bool, string) {
    73  	if e, ok := expected.(Expected); ok {
    74  		partial = e.Partial
    75  		ordered = e.Ordered
    76  		expected = e.Val
    77  	}
    78  
    79  	// Anything
    80  	if expected == AnythingIsFine {
    81  		return true, ""
    82  	}
    83  
    84  	expectedVal := reflect.ValueOf(expected)
    85  	actualVal := reflect.ValueOf(actual)
    86  
    87  	// Nil
    88  	if expected == nil {
    89  		switch actualVal.Kind() {
    90  		case reflect.Bool:
    91  			expected = false
    92  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
    93  			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
    94  			reflect.Float32, reflect.Float64:
    95  			expected = 0.0
    96  		case reflect.String:
    97  			expected = ""
    98  		}
    99  
   100  		if expected == actual {
   101  			return true, ""
   102  		}
   103  	}
   104  
   105  	// Regex
   106  	if expr, ok := expected.(Regex); ok {
   107  		re, err := regexp.Compile(string(expr))
   108  		if err != nil {
   109  			return false, fmt.Sprintf("Failed to compile regexp: %s", err)
   110  		}
   111  
   112  		if actualVal.Kind() != reflect.String {
   113  			return false, fmt.Sprintf("Expected string, got %t (%T)", actual, actual)
   114  		}
   115  
   116  		if !re.MatchString(actualVal.String()) {
   117  			return false, fmt.Sprintf("Value %v did not match regexp '%s'", actual, expr)
   118  		}
   119  
   120  		return true, ""
   121  	}
   122  
   123  	switch expectedVal.Kind() {
   124  
   125  	// Bool
   126  	case reflect.Bool:
   127  		if expected == actual {
   128  			return true, ""
   129  		}
   130  	// Number
   131  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   132  		reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   133  		reflect.Float32, reflect.Float64:
   134  		switch actualVal.Kind() {
   135  		case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64,
   136  			reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64,
   137  			reflect.Float32, reflect.Float64, reflect.String:
   138  			diff := math.Abs(reflectNumber(expectedVal) - reflectNumber(actualVal))
   139  			if diff <= precision {
   140  				return true, ""
   141  			}
   142  
   143  			if precision != 0 {
   144  				return false, fmt.Sprintf("Value %v was not within %f of %v", expected, precision, actual)
   145  			}
   146  
   147  			return false, fmt.Sprintf("Expected %v but got %v", expected, actual)
   148  		}
   149  
   150  	// String
   151  	case reflect.String:
   152  		actualStr := fmt.Sprintf("%v", actual)
   153  		if expected == actualStr {
   154  			return true, ""
   155  		}
   156  	// Struct
   157  	case reflect.Struct:
   158  		// Convert expected struct to map and compare with actual value
   159  		return compare(reflectMap(expectedVal), actual, ordered, partial, precision)
   160  	// Map
   161  	case reflect.Map:
   162  		switch actualVal.Kind() {
   163  		case reflect.Struct:
   164  			// Convert actual struct to map and compare with expected map
   165  			return compare(expected, reflectMap(actualVal), ordered, partial, precision)
   166  		case reflect.Map:
   167  			expectedKeys := expectedVal.MapKeys()
   168  			actualKeys := actualVal.MapKeys()
   169  
   170  			for _, expectedKey := range expectedKeys {
   171  				keyFound := false
   172  				for _, actualKey := range actualKeys {
   173  					if ok, _ := Compare(expectedKey.Interface(), actualKey.Interface()); ok {
   174  						keyFound = true
   175  						break
   176  					}
   177  				}
   178  				if !keyFound {
   179  					return false, fmt.Sprintf("Expected field %v but not found", expectedKey)
   180  				}
   181  			}
   182  
   183  			if !partial {
   184  				expectedKeyVals := reflectMapKeys(expectedKeys)
   185  				actualKeyVals := reflectMapKeys(actualKeys)
   186  				if ok, _ := compare(expectedKeyVals, actualKeyVals, false, false, 0.0); !ok {
   187  					return false, fmt.Sprintf(
   188  						"Unmatched keys from either side: expected fields %v, got %v",
   189  						expectedKeyVals, actualKeyVals,
   190  					)
   191  				}
   192  			}
   193  
   194  			expectedMap := reflectMap(expectedVal)
   195  			actualMap := reflectMap(actualVal)
   196  
   197  			for k, v := range expectedMap {
   198  				if ok, reason := compare(v, actualMap[k], ordered, partial, precision); !ok {
   199  					return false, reason
   200  				}
   201  			}
   202  
   203  			return true, ""
   204  		default:
   205  			return false, fmt.Sprintf("Expected map, got %v (%T)", actual, actual)
   206  		}
   207  	// Slice/Array
   208  	case reflect.Slice, reflect.Array:
   209  		switch actualVal.Kind() {
   210  		case reflect.Slice, reflect.Array:
   211  			if ordered {
   212  				expectedArr := reflectSlice(expectedVal)
   213  				actualArr := reflectSlice(actualVal)
   214  
   215  				j := 0
   216  				for i := 0; i < len(expectedArr); i++ {
   217  					expectedArrVal := expectedArr[i]
   218  					for {
   219  						if j >= len(actualArr) {
   220  							return false, fmt.Sprintf("Ran out of results before finding %v", expectedArrVal)
   221  						}
   222  
   223  						actualArrVal := actualArr[j]
   224  						j++
   225  
   226  						if ok, _ := compare(expectedArrVal, actualArrVal, ordered, partial, precision); ok {
   227  							break
   228  						} else if !partial {
   229  							return false, fmt.Sprintf("Unexpected item %v while looking for %v", actualArrVal, expectedArrVal)
   230  						}
   231  					}
   232  				}
   233  				if !partial && j < len(actualArr) {
   234  					return false, fmt.Sprintf("Unexpected extra results: %v", actualArr[j:])
   235  				}
   236  			} else {
   237  				expectedArr := reflectSlice(expectedVal)
   238  				actualArr := reflectSlice(actualVal)
   239  
   240  				for _, expectedArrVal := range expectedArr {
   241  					found := false
   242  					for j, actualArrVal := range actualArr {
   243  						if ok, _ := compare(expectedArrVal, actualArrVal, ordered, partial, precision); ok {
   244  							found = true
   245  							actualArr = append(actualArr[:j], actualArr[j+1:]...)
   246  							break
   247  						}
   248  					}
   249  					if !found {
   250  						return false, fmt.Sprintf("Missing expected item %v", expectedArrVal)
   251  					}
   252  				}
   253  
   254  				if !partial && len(actualArr) > 0 {
   255  					return false, fmt.Sprintf("Extra items returned: %v", expectedArr)
   256  				}
   257  			}
   258  
   259  			return true, ""
   260  		}
   261  	// Other
   262  	default:
   263  		if expected == actual {
   264  			return true, ""
   265  		}
   266  	}
   267  
   268  	return false, fmt.Sprintf("Expected %v (%T) but got %v (%T)", expected, expected, actual, actual)
   269  }
   270  
   271  func reflectNumber(v reflect.Value) float64 {
   272  	switch v.Kind() {
   273  	case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64:
   274  		return float64(v.Int())
   275  	case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64:
   276  		return float64(v.Uint())
   277  	case reflect.Float32, reflect.Float64:
   278  		return v.Float()
   279  	case reflect.String:
   280  		f, _ := strconv.ParseFloat(v.String(), 64)
   281  		return f
   282  	default:
   283  		return float64(0)
   284  	}
   285  }
   286  
   287  func reflectMap(v reflect.Value) map[interface{}]interface{} {
   288  	switch v.Kind() {
   289  	case reflect.Struct:
   290  		m := map[interface{}]interface{}{}
   291  		for i := 0; i < v.NumField(); i++ {
   292  			sf := v.Type().Field(i)
   293  			if sf.PkgPath != "" && !sf.Anonymous {
   294  				continue // unexported
   295  			}
   296  
   297  			k := sf.Name
   298  			v := v.Field(i).Interface()
   299  
   300  			m[k] = v
   301  		}
   302  		return m
   303  	case reflect.Map:
   304  		m := map[interface{}]interface{}{}
   305  		for _, mk := range v.MapKeys() {
   306  			k := ""
   307  			if mk.Interface() != nil {
   308  				k = fmt.Sprintf("%v", mk.Interface())
   309  			}
   310  			v := v.MapIndex(mk).Interface()
   311  
   312  			m[k] = v
   313  		}
   314  		return m
   315  	default:
   316  		return nil
   317  	}
   318  }
   319  
   320  func reflectSlice(v reflect.Value) []interface{} {
   321  	switch v.Kind() {
   322  	case reflect.Slice, reflect.Array:
   323  		s := []interface{}{}
   324  		for i := 0; i < v.Len(); i++ {
   325  			s = append(s, v.Index(i).Interface())
   326  		}
   327  		return s
   328  	default:
   329  		return nil
   330  	}
   331  }
   332  
   333  func reflectMapKeys(keys []reflect.Value) []interface{} {
   334  	s := []interface{}{}
   335  	for _, key := range keys {
   336  		s = append(s, key.Interface())
   337  	}
   338  	return s
   339  }
   340  
   341  func reflectInterfaces(vals []reflect.Value) []interface{} {
   342  	ret := []interface{}{}
   343  	for _, val := range vals {
   344  		ret = append(ret, val.Interface())
   345  	}
   346  	return ret
   347  }