github.com/markusbkk/elvish@v0.0.0-20231204143114-91dc52438621/pkg/tt/tt.go (about)

     1  // Package tt supports table-driven tests with little boilerplate.
     2  //
     3  // See the test case for this package for example usage.
     4  package tt
     5  
     6  import (
     7  	"bytes"
     8  	"fmt"
     9  	"reflect"
    10  )
    11  
// Table represents a test table: an ordered list of test cases. Test walks the
// table in order and runs every Case in it.
type Table []*Case
    14  
// Case represents a test case. It is created by the Args function, and offers
// setters that augment and return itself; those calls can be chained like
// Args(...).Rets(...).
//
// The args field holds the arguments passed to the function under test. The
// retsMatchers field holds one set of expected return values (or Matchers) per
// Rets call; Test checks the actual return values against every set.
type Case struct {
	args         []interface{}
	retsMatchers [][]interface{}
}
    22  
    23  // Args returns a new Case with the given arguments.
    24  func Args(args ...interface{}) *Case {
    25  	return &Case{args: args}
    26  }
    27  
    28  // Rets modifies the test case so that it requires the return values to match
    29  // the given values. It returns the receiver. The arguments may implement the
    30  // Matcher interface, in which case its Match method is called with the actual
    31  // return value. Otherwise, reflect.DeepEqual is used to determine matches.
    32  func (c *Case) Rets(matchers ...interface{}) *Case {
    33  	c.retsMatchers = append(c.retsMatchers, matchers)
    34  	return c
    35  }
    36  
// FnToTest describes a function to test. It is created by Fn and can be
// customized with the ArgsFmt and RetsFmt setters.
//
// The name field is the function name shown in error messages, and body is the
// function value that Test calls via reflection. The argsFmt and retsFmt fields
// are optional fmt format strings for rendering arguments and return values in
// error messages; when empty, a default comma-separated rendering is used.
type FnToTest struct {
	name    string
	body    interface{}
	argsFmt string
	retsFmt string
}
    44  
    45  // Fn makes a new FnToTest with the given function name and body.
    46  func Fn(name string, body interface{}) *FnToTest {
    47  	return &FnToTest{name: name, body: body}
    48  }
    49  
// ArgsFmt sets the format string used to render a test case's arguments in
// test error messages, and returns fn itself so that calls can be chained.
// The string is passed to fmt.Sprintf together with the arguments. When it is
// empty (the default), the arguments are printed separated by ", " instead.
func (fn *FnToTest) ArgsFmt(s string) *FnToTest {
	fn.argsFmt = s
	return fn
}
    56  
// RetsFmt sets the format string used to render return values in test error
// messages, and returns fn itself so that calls can be chained. The string is
// passed to fmt.Sprintf once with the actual return values and once with the
// expected ones. When it is empty (the default), a single value is printed as
// is, and any other number of values is printed as a parenthesized,
// comma-separated list.
func (fn *FnToTest) RetsFmt(s string) *FnToTest {
	fn.retsFmt = s
	return fn
}
    63  
// T is the interface for accessing testing.T. It lists only the two methods
// that Test calls, so a substitute implementation can be supplied in place of
// a real *testing.T.
type T interface {
	// Helper is called once at the start of Test.
	Helper()
	// Errorf is called once for every set of expected return values that the
	// actual return values fail to match.
	Errorf(format string, args ...interface{})
}
    69  
    70  // Test tests a function against test cases.
    71  func Test(t T, fn *FnToTest, tests Table) {
    72  	t.Helper()
    73  	for _, test := range tests {
    74  		rets := call(fn.body, test.args)
    75  		for _, retsMatcher := range test.retsMatchers {
    76  			if !match(retsMatcher, rets) {
    77  				var argsString, retsString, wantRetsString string
    78  				if fn.argsFmt == "" {
    79  					argsString = sprintArgs(test.args...)
    80  				} else {
    81  					argsString = fmt.Sprintf(fn.argsFmt, test.args...)
    82  				}
    83  				if fn.retsFmt == "" {
    84  					retsString = sprintRets(rets...)
    85  					wantRetsString = sprintRets(retsMatcher...)
    86  				} else {
    87  					retsString = fmt.Sprintf(fn.retsFmt, rets...)
    88  					wantRetsString = fmt.Sprintf(fn.retsFmt, retsMatcher...)
    89  				}
    90  				t.Errorf("%s(%s) -> %s, want %s", fn.name, argsString, retsString, wantRetsString)
    91  			}
    92  		}
    93  	}
    94  }
    95  
// RetValue is an empty interface used in the Matcher interface. It can hold
// any return value; it exists as a separate named type so that the Match
// method of Matcher has a signature specific to this package.
type RetValue interface{}
    98  
// Matcher wraps the Match method. A value implementing Matcher can be passed
// to Rets in place of a concrete expected value; its Match method then decides
// whether the actual return value is acceptable.
type Matcher interface {
	// Match reports whether a return value is considered a match. The argument
	// is of type RetValue so that it cannot be implemented accidentally.
	Match(RetValue) bool
}
   105  
// Any is a Matcher that matches any value.
var Any Matcher = anyMatcher{}

// anyMatcher is the stateless Matcher implementation behind Any.
type anyMatcher struct{}

// Match always returns true, regardless of the value it is given.
func (anyMatcher) Match(RetValue) bool { return true }
   112  
   113  func match(matchers, actual []interface{}) bool {
   114  	for i, matcher := range matchers {
   115  		if !matchOne(matcher, actual[i]) {
   116  			return false
   117  		}
   118  	}
   119  	return true
   120  }
   121  
   122  func matchOne(m, a interface{}) bool {
   123  	if m, ok := m.(Matcher); ok {
   124  		return m.Match(a)
   125  	}
   126  	return reflect.DeepEqual(m, a)
   127  }
   128  
// sprintArgs renders the arguments of a test case for use in error messages:
// the arguments are printed one after another, separated by ", ".
func sprintArgs(args ...interface{}) string {
	return sprintCommaDelimited(args...)
}
   132  
   133  func sprintRets(rets ...interface{}) string {
   134  	if len(rets) == 1 {
   135  		return fmt.Sprint(rets[0])
   136  	}
   137  	return "(" + sprintCommaDelimited(rets...) + ")"
   138  }
   139  
   140  func sprintCommaDelimited(args ...interface{}) string {
   141  	var b bytes.Buffer
   142  	for i, arg := range args {
   143  		if i > 0 {
   144  			b.WriteString(", ")
   145  		}
   146  		fmt.Fprint(&b, arg)
   147  	}
   148  	return b.String()
   149  }
   150  
   151  func call(fn interface{}, args []interface{}) []interface{} {
   152  	argsReflect := make([]reflect.Value, len(args))
   153  	for i, arg := range args {
   154  		if arg == nil {
   155  			// reflect.ValueOf(nil) returns a zero Value, but this is not what
   156  			// we want. Work around this by taking the ValueOf a pointer to nil
   157  			// and then get the Elem.
   158  			// TODO(xiaq): This is now always using a nil value with type
   159  			// interface{}. For more usability, inspect the type of fn to see
   160  			// which type of nil this argument should be.
   161  			var v interface{}
   162  			argsReflect[i] = reflect.ValueOf(&v).Elem()
   163  		} else {
   164  			argsReflect[i] = reflect.ValueOf(arg)
   165  		}
   166  	}
   167  	retsReflect := reflect.ValueOf(fn).Call(argsReflect)
   168  	rets := make([]interface{}, len(retsReflect))
   169  	for i, retReflect := range retsReflect {
   170  		rets[i] = retReflect.Interface()
   171  	}
   172  	return rets
   173  }