github.com/hoveychen/protoreflect@v1.4.7-0.20221103114119-0b4b3385ec76/internal/testutil/asserts.go (about)

     1  package testutil
     2  
     3  import (
     4  	"bytes"
     5  	"fmt"
     6  	"math"
     7  	"os"
     8  	"reflect"
     9  	"runtime"
    10  	"strings"
    11  	"testing"
    12  )
    13  
    14  // Ceq is a custom equals check; the given function returns true if its arguments are equal
    15  func Ceq(t *testing.T, expected, actual interface{}, eq func(a, b interface{}) bool, context ...interface{}) bool {
    16  	return ceq(getCaller(), t, expected, actual, eq, context)
    17  }
    18  
    19  func ceq(caller string, t *testing.T, expected, actual interface{}, eq func(a, b interface{}) bool, context []interface{}) bool {
    20  	e := eq(expected, actual)
    21  	require(caller, t, e, mergeContext(context, "Expecting %v (%v), got %v (%v)", expected, reflect.TypeOf(expected), actual, reflect.TypeOf(actual)))
    22  	return e
    23  }
    24  
    25  // Cneq is a custom not-equals check; the given function returns true if its arguments are equal
    26  func Cneq(t *testing.T, unexpected, actual interface{}, eq func(a, b interface{}) bool, context ...interface{}) bool {
    27  	return cneq(getCaller(), t, unexpected, actual, eq, context)
    28  }
    29  
    30  func cneq(caller string, t *testing.T, unexpected, actual interface{}, eq func(a, b interface{}) bool, context []interface{}) bool {
    31  	ne := !eq(unexpected, actual)
    32  	require(caller, t, ne, mergeContext(context, "Value should not be %v (%v)", unexpected, reflect.TypeOf(unexpected)))
    33  	return ne
    34  }
    35  
    36  // Require is an assertion that logs a failure if its given argument is not true
    37  func Require(t *testing.T, condition bool, context ...interface{}) {
    38  	require(getCaller(), t, condition, context)
    39  }
    40  
    41  func require(caller string, t *testing.T, condition bool, context []interface{}) {
    42  	if !condition {
    43  		if len(context) == 0 {
    44  			t.Fatalf("%s: Assertion failed", caller)
    45  		} else {
    46  			msg := context[0].(string)
    47  			// if any args were deferred (e.g. a function instead of a value), get those args now
    48  			args := make([]interface{}, len(context)-1)
    49  			for i, a := range context[1:] {
    50  				rv := reflect.ValueOf(a)
    51  				if rv.Kind() == reflect.Func {
    52  					a = rv.Call([]reflect.Value{})[0].Interface()
    53  				}
    54  				args[i] = a
    55  			}
    56  			t.Fatalf("%s: %s", caller, fmt.Sprintf(msg, args...))
    57  		}
    58  	}
    59  }
    60  
    61  func mergeContext(context []interface{}, msg string, msgArgs ...interface{}) []interface{} {
    62  	if len(context) == 0 {
    63  		ret := make([]interface{}, len(msgArgs)+1)
    64  		ret[0] = msg
    65  		for i, a := range msgArgs {
    66  			ret[i+1] = a
    67  		}
    68  		return ret
    69  	} else {
    70  		ret := make([]interface{}, len(msgArgs)+2)
    71  		ret[0] = msg + ": %s"
    72  		for i, a := range msgArgs {
    73  			ret[i+1] = a
    74  		}
    75  		ret[len(ret)-1] = func() string {
    76  			f := context[0].(string)
    77  			return fmt.Sprintf(f, context[1:]...)
    78  		}
    79  		return ret
    80  	}
    81  }
    82  
    83  func getCaller() string {
    84  	pc, file, line, ok := runtime.Caller(2)
    85  	if !ok {
    86  		return "?"
    87  	}
    88  	fn := runtime.FuncForPC(pc)
    89  	var fnName string
    90  	if fn == nil {
    91  		fnName = "?"
    92  	} else {
    93  		fnName = fn.Name()
    94  	}
    95  	return fmt.Sprintf("%s(%s:%d)", lastComponents(fnName, 1), lastComponents(file, 2), line)
    96  }
    97  
    98  const pathSep = string(os.PathSeparator)
    99  
   100  func lastComponents(s string, count int) string {
   101  	ss := s
   102  	var i int
   103  	for count > 0 {
   104  		i = strings.LastIndex(ss, pathSep)
   105  		if i < 0 {
   106  			return s
   107  		}
   108  		count--
   109  		ss = ss[:i]
   110  	}
   111  	return s[i+1:]
   112  }
   113  
   114  // Ok asserts that the given error is nil
   115  func Ok(t *testing.T, err error, context ...interface{}) {
   116  	require(getCaller(), t, err == nil, mergeContext(context, "Unexpected error: %s", func() interface{} { return err.Error() }))
   117  }
   118  
   119  // Nok asserts that the given error is not nil
   120  func Nok(t *testing.T, err error, context ...interface{}) {
   121  	require(getCaller(), t, err != nil, mergeContext(context, "Expected error but got none"))
   122  }
   123  
   124  // Eq asserts that the given two values are equal
   125  func Eq(t *testing.T, expected, actual interface{}, context ...interface{}) bool {
   126  	return ceq(getCaller(), t, expected, actual, eqany, context)
   127  }
   128  
   129  // Neq asserts that the given two values are not equal
   130  func Neq(t *testing.T, unexpected, actual interface{}, context ...interface{}) bool {
   131  	return cneq(getCaller(), t, unexpected, actual, eqany, context)
   132  }
   133  
   134  // default equality test and helpers
   135  
   136  func eqany(expected, actual interface{}) bool {
   137  	if expected == nil && actual == nil {
   138  		return true
   139  	}
   140  	if expected == nil || actual == nil {
   141  		return false
   142  	}
   143  
   144  	// We don't want reflect.DeepEquals because of its recursive nature. So we need
   145  	// a custom compare for slices and maps. Two slices are equal if they have the
   146  	// same number of elements and the elements at the same index are equal to each
   147  	// other. Two maps are equal if their key sets are the same and the corresponding
   148  	// values are equal. (The relationship is not recursive,  slices or maps that
   149  	// contain other slices or maps can't be tested.)
   150  	et := reflect.TypeOf(expected)
   151  
   152  	if et.Kind() == reflect.Slice {
   153  		return eqslice(reflect.ValueOf(expected), reflect.ValueOf(actual))
   154  	} else if et.Kind() == reflect.Map {
   155  		return eqmap(reflect.ValueOf(expected), reflect.ValueOf(actual))
   156  	} else {
   157  		return eqscalar(expected, actual)
   158  	}
   159  }
   160  
   161  func eqscalar(expected, actual interface{}) bool {
   162  	// special-case simple equality for []byte (since slices aren't directly comparable)
   163  	if e, ok := expected.([]byte); ok {
   164  		a, ok := actual.([]byte)
   165  		return ok && bytes.Equal(e, a)
   166  	}
   167  	// and special-cases to handle NaN
   168  	if e, ok := expected.(float32); ok && math.IsNaN(float64(e)) {
   169  		a, ok := actual.(float32)
   170  		return ok && math.IsNaN(float64(a))
   171  	}
   172  	if e, ok := expected.(float64); ok && math.IsNaN(e) {
   173  		a, ok := actual.(float64)
   174  		return ok && math.IsNaN(a)
   175  	}
   176  	// simple logic for everything else
   177  	return expected == actual
   178  }
   179  
   180  func eqslice(expected, actual reflect.Value) bool {
   181  	if expected.Len() != actual.Len() {
   182  		return false
   183  	}
   184  	for i := 0; i < expected.Len(); i++ {
   185  		e := expected.Index(i).Interface()
   186  		a := actual.Index(i).Interface()
   187  		if !eqscalar(e, a) {
   188  			return false
   189  		}
   190  	}
   191  	return true
   192  }
   193  
   194  func eqmap(expected, actual reflect.Value) bool {
   195  	if expected.Len() != actual.Len() {
   196  		return false
   197  	}
   198  	for _, k := range expected.MapKeys() {
   199  		e := expected.MapIndex(k)
   200  		a := actual.MapIndex(k)
   201  		if !a.IsValid() {
   202  			return false
   203  		}
   204  		if !eqscalar(e.Interface(), a.Interface()) {
   205  			return false
   206  		}
   207  	}
   208  	return true
   209  }