github.com/bakjos/protoreflect@v1.9.2/internal/testutil/asserts.go (about)

     1  package testutil
     2  
     3  import (
     4  	"bytes"
     5  	"math"
     6  	"reflect"
     7  	"testing"
     8  )
     9  
    10  // Ceq is a custom equals check; the given function returns true if its arguments are equal
    11  func Ceq(t *testing.T, expected, actual interface{}, eq func(a, b interface{}) bool, context ...interface{}) bool {
    12  	t.Helper()
    13  	e := eq(expected, actual)
    14  	Require(t, e, mergeContext(context, "Expecting %v (%v), got %v (%v)", expected, reflect.TypeOf(expected), actual, reflect.TypeOf(actual))...)
    15  	return e
    16  }
    17  
    18  // Cneq is a custom not-equals check; the given function returns true if its arguments are equal
    19  func Cneq(t *testing.T, unexpected, actual interface{}, eq func(a, b interface{}) bool, context ...interface{}) bool {
    20  	t.Helper()
    21  	ne := !eq(unexpected, actual)
    22  	Require(t, ne, mergeContext(context, "Value should not be %v (%v)", unexpected, reflect.TypeOf(unexpected))...)
    23  	return ne
    24  }
    25  
    26  // Require is an assertion that logs a failure if its given argument is not true
    27  func Require(t *testing.T, condition bool, context ...interface{}) {
    28  	t.Helper()
    29  	if !condition {
    30  		if len(context) == 0 {
    31  			t.Fatalf("Assertion failed")
    32  		} else {
    33  			msg := context[0].(string)
    34  			// if any args were deferred (e.g. a function instead of a value), get those args now
    35  			args := make([]interface{}, len(context)-1)
    36  			for i, a := range context[1:] {
    37  				rv := reflect.ValueOf(a)
    38  				if rv.Kind() == reflect.Func {
    39  					a = rv.Call([]reflect.Value{})[0].Interface()
    40  				}
    41  				args[i] = a
    42  			}
    43  			t.Fatalf(msg, args...)
    44  		}
    45  	}
    46  }
    47  
    48  func mergeContext(context []interface{}, msg string, msgArgs ...interface{}) []interface{} {
    49  	if len(context) == 0 {
    50  		ret := make([]interface{}, 0, len(msgArgs)+1)
    51  		ret = append(ret, msg)
    52  		ret = append(ret, msgArgs...)
    53  		return ret
    54  	} else {
    55  		ret := make([]interface{}, 0, len(context)+len(msgArgs))
    56  		ret = append(ret, msg+": "+context[0].(string))
    57  		ret = append(ret, msgArgs...)
    58  		ret = append(ret, context[1:]...)
    59  		return ret
    60  	}
    61  }
    62  
    63  // Ok asserts that the given error is nil
    64  func Ok(t *testing.T, err error, context ...interface{}) {
    65  	t.Helper()
    66  	Require(t, err == nil, mergeContext(context, "Unexpected error: %s", func() interface{} { return err.Error() })...)
    67  }
    68  
    69  // Nok asserts that the given error is not nil
    70  func Nok(t *testing.T, err error, context ...interface{}) {
    71  	t.Helper()
    72  	Require(t, err != nil, mergeContext(context, "Expected error but got none")...)
    73  }
    74  
    75  // Eq asserts that the given two values are equal
    76  func Eq(t *testing.T, expected, actual interface{}, context ...interface{}) bool {
    77  	t.Helper()
    78  	return Ceq(t, expected, actual, eqany, context...)
    79  }
    80  
    81  // Neq asserts that the given two values are not equal
    82  func Neq(t *testing.T, unexpected, actual interface{}, context ...interface{}) bool {
    83  	t.Helper()
    84  	return Cneq(t, unexpected, actual, eqany, context...)
    85  }
    86  
    87  // default equality test and helpers
    88  
    89  func eqany(expected, actual interface{}) bool {
    90  	if expected == nil && actual == nil {
    91  		return true
    92  	}
    93  	if expected == nil || actual == nil {
    94  		return false
    95  	}
    96  
    97  	// We don't want reflect.DeepEquals because of its recursive nature. So we need
    98  	// a custom compare for slices and maps. Two slices are equal if they have the
    99  	// same number of elements and the elements at the same index are equal to each
   100  	// other. Two maps are equal if their key sets are the same and the corresponding
   101  	// values are equal. (The relationship is not recursive,  slices or maps that
   102  	// contain other slices or maps can't be tested.)
   103  	et := reflect.TypeOf(expected)
   104  
   105  	if et.Kind() == reflect.Slice {
   106  		return eqslice(reflect.ValueOf(expected), reflect.ValueOf(actual))
   107  	} else if et.Kind() == reflect.Map {
   108  		return eqmap(reflect.ValueOf(expected), reflect.ValueOf(actual))
   109  	} else {
   110  		return eqscalar(expected, actual)
   111  	}
   112  }
   113  
   114  func eqscalar(expected, actual interface{}) bool {
   115  	// special-case simple equality for []byte (since slices aren't directly comparable)
   116  	if e, ok := expected.([]byte); ok {
   117  		a, ok := actual.([]byte)
   118  		return ok && bytes.Equal(e, a)
   119  	}
   120  	// and special-cases to handle NaN
   121  	if e, ok := expected.(float32); ok && math.IsNaN(float64(e)) {
   122  		a, ok := actual.(float32)
   123  		return ok && math.IsNaN(float64(a))
   124  	}
   125  	if e, ok := expected.(float64); ok && math.IsNaN(e) {
   126  		a, ok := actual.(float64)
   127  		return ok && math.IsNaN(a)
   128  	}
   129  	// simple logic for everything else
   130  	return expected == actual
   131  }
   132  
   133  func eqslice(expected, actual reflect.Value) bool {
   134  	if expected.Len() != actual.Len() {
   135  		return false
   136  	}
   137  	for i := 0; i < expected.Len(); i++ {
   138  		e := expected.Index(i).Interface()
   139  		a := actual.Index(i).Interface()
   140  		if !eqscalar(e, a) {
   141  			return false
   142  		}
   143  	}
   144  	return true
   145  }
   146  
   147  func eqmap(expected, actual reflect.Value) bool {
   148  	if expected.Len() != actual.Len() {
   149  		return false
   150  	}
   151  	for _, k := range expected.MapKeys() {
   152  		e := expected.MapIndex(k)
   153  		a := actual.MapIndex(k)
   154  		if !a.IsValid() {
   155  			return false
   156  		}
   157  		if !eqscalar(e.Interface(), a.Interface()) {
   158  			return false
   159  		}
   160  	}
   161  	return true
   162  }