src.elv.sh@v0.21.0-dev.0.20240515223629-06979efb9a2a/pkg/tt/tt.go (about)

     1  // Package tt supports table-driven tests with little boilerplate.
     2  //
     3  // A typical use of this package looks like this:
     4  //
     5  //	// Function being tested
//	func Neg(i int) int { return -i }
     7  //
     8  //	func TestNeg(t *testing.T) {
     9  //		Test(t, Neg,
    10  //			// Unnamed test case
    11  //			Args(1).Rets(-1),
    12  //			// Named test case
    13  //			It("returns 0 for 0").Args(0).Rets(0),
    14  //		)
    15  //	}
    16  package tt
    17  
    18  import (
    19  	"fmt"
    20  	"path/filepath"
    21  	"reflect"
    22  	"runtime"
    23  	"strings"
    24  	"testing"
    25  
    26  	"github.com/google/go-cmp/cmp"
    27  )
    28  
// Case represents a test case. It has setter methods that augment and return
// itself, so they can be chained like It(...).Args(...).Rets(...).
type Case struct {
	// Source location of the call to [It] or [Args] that created this Case,
	// formatted as "file.go:line" with the directory stripped. It is included
	// in failure messages to point back at the case.
	fileAndLine string

	// Description given to [It]; used as the subtest name. Empty for cases
	// created with the package-level [Args].
	desc string

	// Arguments the function under test is called with.
	args []any

	// Expected return values, one slice per call to [Case.Rets]. Each slice is
	// checked independently against the same call's results.
	retsMatchers [][]any
}
    37  
    38  // It returns a Case with the given text description.
    39  func It(desc string) *Case {
    40  	return &Case{fileAndLine: fileAndLine(2), desc: desc}
    41  }
    42  
    43  // Args is equivalent to It("").args(...). It is useful when the test case is
    44  // trivial and doesn't need a description; for more complex or interesting test
    45  // cases, use [It] instead.
    46  func Args(args ...any) *Case {
    47  	return &Case{fileAndLine: fileAndLine(2), args: args}
    48  }
    49  
    50  func fileAndLine(skip int) string {
    51  	_, filename, line, _ := runtime.Caller(skip)
    52  	return fmt.Sprintf("%s:%d", filepath.Base(filename), line)
    53  }
    54  
// Args modifies the Case to pass the given arguments to the function under
// test, replacing any arguments set earlier (such as those given to the
// package-level [Args]). The slice is stored as is, without copying. It
// returns the receiver.
func (c *Case) Args(args ...any) *Case {
	c.args = args
	return c
}
    60  
// Rets modifies the Case to expect the given return values. It returns the
// receiver.
//
// The arguments are compared positionally with the actual return values. They
// may implement the [Matcher] interface, in which case its Match method is
// called with the actual return value. Otherwise, [reflect.DeepEqual] is used
// to determine matches.
//
// Rets may be called more than once on the same Case. Each call adds another
// set of expectations. The function under test is still called only once, and
// every set that does not match is reported as a separate failure.
func (c *Case) Rets(matchers ...any) *Case {
	c.retsMatchers = append(c.retsMatchers, matchers)
	return c
}
    71  
// FnDescriptor describes a function to test. It has setter methods that augment
// and return itself, so they can be chained like
// Fn(...).Named(...).ArgsFmt(...).
type FnDescriptor struct {
	// Name shown in failure messages. When empty, the test runner tries to
	// infer it via reflection.
	name string

	// The function under test; it is invoked via reflection.
	body any

	// Format string for the arguments in failure messages, passed to
	// fmt.Sprintf together with the case's arguments. When empty, the
	// arguments are printed with their default format, separated by ", ".
	argsFmt string
}
    80  
    81  // Fn creates a FnDescriptor for the given function.
    82  func Fn(body any) *FnDescriptor {
    83  	return &FnDescriptor{body: body}
    84  }
    85  
// Named sets the name of the function, as shown in test failure messages. This
// is only necessary for methods and local closures; package-level functions
// will have their name automatically inferred via reflection. A name set here
// takes precedence over inference. It returns the receiver.
func (fn *FnDescriptor) Named(name string) *FnDescriptor {
	fn.name = name
	return fn
}
    93  
// ArgsFmt sets the string for formatting arguments in test error messages. The
// string is passed to fmt.Sprintf together with the case's arguments, replacing
// the default of printing each argument with its default format, separated by
// ", ". It returns the receiver.
func (fn *FnDescriptor) ArgsFmt(s string) *FnDescriptor {
	fn.argsFmt = s
	return fn
}
   100  
   101  // Test tests fn against the given Case instances.
   102  //
   103  // The fn argument may be the function itself or an explicit [FnDescriptor], the
   104  // former case being equivalent to passing Fn(fn).
   105  func Test(t *testing.T, fn any, tests ...*Case) {
   106  	testInner[*testing.T](t, fn, tests...)
   107  }
   108  
   109  // Instead of using [*testing.T] directly, the inner implementation uses two
   110  // interfaces so that it can be mocked. We need two interfaces because
   111  // type parameters can't refer to the type itself.
   112  
   113  type testRunner[T subtestRunner] interface {
   114  	Helper()
   115  	Run(name string, f func(t T)) bool
   116  }
   117  
   118  type subtestRunner interface {
   119  	Errorf(format string, args ...any)
   120  }
   121  
   122  func testInner[T subtestRunner](t testRunner[T], fn any, tests ...*Case) {
   123  	t.Helper()
   124  	var fnd *FnDescriptor
   125  	switch fn := fn.(type) {
   126  	case *FnDescriptor:
   127  		fnd = &FnDescriptor{}
   128  		*fnd = *fn
   129  	default:
   130  		fnd = Fn(fn)
   131  	}
   132  	if fnd.name == "" {
   133  		// Use reflection to discover the function's name.
   134  		name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name()
   135  		// Tests are usually restricted to functions in the same package, so
   136  		// elide the package name.
   137  		if i := strings.LastIndexByte(name, '.'); i != -1 {
   138  			name = name[i+1:]
   139  		}
   140  		fnd.name = name
   141  	}
   142  
   143  	for _, test := range tests {
   144  		t.Run(test.desc, func(t T) {
   145  			rets := call(fnd.body, test.args)
   146  			for _, retsMatcher := range test.retsMatchers {
   147  				if !match(retsMatcher, rets) {
   148  					var args string
   149  					if fnd.argsFmt == "" {
   150  						args = sprintArgs(test.args...)
   151  					} else {
   152  						args = fmt.Sprintf(fnd.argsFmt, test.args...)
   153  					}
   154  					var diff string
   155  					if len(retsMatcher) == 1 && len(rets) == 1 {
   156  						diff = cmp.Diff(retsMatcher[0], rets[0], cmpopt)
   157  					} else {
   158  						diff = cmp.Diff(retsMatcher, rets, cmpopt)
   159  					}
   160  					t.Errorf("%s: %s(%s) returns (-want +got):\n%s",
   161  						test.fileAndLine, fnd.name, args, diff)
   162  				}
   163  			}
   164  		})
   165  	}
   166  }
   167  
// RetValue is the parameter type of [Matcher.Match]; it holds one actual return
// value of the function under test. It is a distinct named type rather than
// plain any, so a method must be declared as Match(RetValue) bool to implement
// [Matcher]. This keeps types from implementing it by accident.
type RetValue any
   170  
// Matcher wraps the Match method.
//
// Values that implement this interface can be passed to [*Case.Rets] to control
// the matching algorithm for return values. The Match method is then called with
// the corresponding actual return value, instead of comparing the two with
// [reflect.DeepEqual].
type Matcher interface {
	// Match reports whether a return value is considered a match. The argument
	// is of type RetValue so that it cannot be implemented accidentally.
	Match(RetValue) bool
}
   180  
   181  // Any is a Matcher that matches any value.
   182  var Any Matcher = anyMatcher{}
   183  
   184  type anyMatcher struct{}
   185  
   186  func (anyMatcher) Match(RetValue) bool { return true }
   187  
   188  func match(matchers, actual []any) bool {
   189  	for i, matcher := range matchers {
   190  		if !matchOne(matcher, actual[i]) {
   191  			return false
   192  		}
   193  	}
   194  	return true
   195  }
   196  
   197  func matchOne(m, a any) bool {
   198  	if m, ok := m.(Matcher); ok {
   199  		return m.Match(a)
   200  	}
   201  	return reflect.DeepEqual(m, a)
   202  }
   203  
   204  func sprintArgs(args ...any) string {
   205  	var b strings.Builder
   206  	for i, arg := range args {
   207  		if i > 0 {
   208  			b.WriteString(", ")
   209  		}
   210  		fmt.Fprint(&b, arg)
   211  	}
   212  	return b.String()
   213  }
   214  
   215  func call(fn any, args []any) []any {
   216  	argsReflect := make([]reflect.Value, len(args))
   217  	for i, arg := range args {
   218  		if arg == nil {
   219  			// reflect.ValueOf(nil) returns a zero Value, but this is not what
   220  			// we want. Work around this by taking the ValueOf a pointer to nil
   221  			// and then get the Elem.
   222  			// TODO(xiaq): This is now always using a nil value with type
   223  			// interface{}. For more usability, inspect the type of fn to see
   224  			// which type of nil this argument should be.
   225  			var v any
   226  			argsReflect[i] = reflect.ValueOf(&v).Elem()
   227  		} else {
   228  			argsReflect[i] = reflect.ValueOf(arg)
   229  		}
   230  	}
   231  	retsReflect := reflect.ValueOf(fn).Call(argsReflect)
   232  	rets := make([]any, len(retsReflect))
   233  	for i, retReflect := range retsReflect {
   234  		rets[i] = retReflect.Interface()
   235  	}
   236  	return rets
   237  }