src.elv.sh@v0.21.0-dev.0.20240515223629-06979efb9a2a/pkg/tt/tt.go (about) 1 // Package tt supports table-driven tests with little boilerplate. 2 // 3 // A typical use of this package looks like this: 4 // 5 // // Function being tested 6 // func Neg(i int) { return -i } 7 // 8 // func TestNeg(t *testing.T) { 9 // Test(t, Neg, 10 // // Unnamed test case 11 // Args(1).Rets(-1), 12 // // Named test case 13 // It("returns 0 for 0").Args(0).Rets(0), 14 // ) 15 // } 16 package tt 17 18 import ( 19 "fmt" 20 "path/filepath" 21 "reflect" 22 "runtime" 23 "strings" 24 "testing" 25 26 "github.com/google/go-cmp/cmp" 27 ) 28 29 // Case represents a test case. It has setter methods that augment and return 30 // itself, so they can be chained like It(...).Args(...).Rets(...). 31 type Case struct { 32 fileAndLine string 33 desc string 34 args []any 35 retsMatchers [][]any 36 } 37 38 // It returns a Case with the given text description. 39 func It(desc string) *Case { 40 return &Case{fileAndLine: fileAndLine(2), desc: desc} 41 } 42 43 // Args is equivalent to It("").args(...). It is useful when the test case is 44 // trivial and doesn't need a description; for more complex or interesting test 45 // cases, use [It] instead. 46 func Args(args ...any) *Case { 47 return &Case{fileAndLine: fileAndLine(2), args: args} 48 } 49 50 func fileAndLine(skip int) string { 51 _, filename, line, _ := runtime.Caller(skip) 52 return fmt.Sprintf("%s:%d", filepath.Base(filename), line) 53 } 54 55 // Args modifies the Case to pass the given arguments. It returns the receiver. 56 func (c *Case) Args(args ...any) *Case { 57 c.args = args 58 return c 59 } 60 61 // Rets modifies the Case to expect the given return values. It returns the 62 // receiver. 63 // 64 // The arguments may implement the [Matcher] interface, in which case its Match 65 // method is called with the actual return value. Otherwise, [reflect.DeepEqual] 66 // is used to determine matches. 
func (c *Case) Rets(matchers ...any) *Case {
	c.retsMatchers = append(c.retsMatchers, matchers)
	return c
}

// FnDescriptor describes a function under test together with how it is
// presented in error messages. Its setter methods augment and return the
// receiver, so they can be chained like Fn(...).Named(...).ArgsFmt(...).
type FnDescriptor struct {
	name    string // shown in error messages; inferred by Test if empty
	body    any    // the function to call
	argsFmt string // format for the arguments in error messages; see ArgsFmt
}

// Fn wraps body, the function to be tested, in a FnDescriptor with no explicit
// name and the default argument format.
func Fn(body any) *FnDescriptor {
	d := new(FnDescriptor)
	d.body = body
	return d
}

// Named sets the name of the function shown in error messages. Setting it is
// only necessary for methods and local closures; the name of a package-level
// function is inferred via reflection. It returns the receiver.
func (d *FnDescriptor) Named(name string) *FnDescriptor {
	d.name = name
	return d
}

// ArgsFmt sets the [fmt.Sprintf] format string used to render the arguments in
// test error messages. When unset, each argument is printed with [fmt.Fprint]
// and the results are joined with ", ". It returns the receiver.
func (d *FnDescriptor) ArgsFmt(s string) *FnDescriptor {
	d.argsFmt = s
	return d
}

// Test runs each Case as a subtest of t, calling fn with the Case's arguments
// and reporting an error for every set of expected return values that does not
// match.
//
// The fn argument may be the function itself or an explicit [FnDescriptor], the
// former case being equivalent to passing Fn(fn).
func Test(t *testing.T, fn any, tests ...*Case) {
	var runner testRunner[*testing.T] = t
	testInner(runner, fn, tests...)
}

// Rather than depending on [*testing.T] directly, the inner implementation is
// written against two interfaces so that it can be mocked. Two interfaces are
// needed because a type parameter can't refer to the type itself.
112 113 type testRunner[T subtestRunner] interface { 114 Helper() 115 Run(name string, f func(t T)) bool 116 } 117 118 type subtestRunner interface { 119 Errorf(format string, args ...any) 120 } 121 122 func testInner[T subtestRunner](t testRunner[T], fn any, tests ...*Case) { 123 t.Helper() 124 var fnd *FnDescriptor 125 switch fn := fn.(type) { 126 case *FnDescriptor: 127 fnd = &FnDescriptor{} 128 *fnd = *fn 129 default: 130 fnd = Fn(fn) 131 } 132 if fnd.name == "" { 133 // Use reflection to discover the function's name. 134 name := runtime.FuncForPC(reflect.ValueOf(fn).Pointer()).Name() 135 // Tests are usually restricted to functions in the same package, so 136 // elide the package name. 137 if i := strings.LastIndexByte(name, '.'); i != -1 { 138 name = name[i+1:] 139 } 140 fnd.name = name 141 } 142 143 for _, test := range tests { 144 t.Run(test.desc, func(t T) { 145 rets := call(fnd.body, test.args) 146 for _, retsMatcher := range test.retsMatchers { 147 if !match(retsMatcher, rets) { 148 var args string 149 if fnd.argsFmt == "" { 150 args = sprintArgs(test.args...) 151 } else { 152 args = fmt.Sprintf(fnd.argsFmt, test.args...) 153 } 154 var diff string 155 if len(retsMatcher) == 1 && len(rets) == 1 { 156 diff = cmp.Diff(retsMatcher[0], rets[0], cmpopt) 157 } else { 158 diff = cmp.Diff(retsMatcher, rets, cmpopt) 159 } 160 t.Errorf("%s: %s(%s) returns (-want +got):\n%s", 161 test.fileAndLine, fnd.name, args, diff) 162 } 163 } 164 }) 165 } 166 } 167 168 // RetValue is an empty interface used in the [Matcher] interface. 169 type RetValue any 170 171 // Matcher wraps the Match method. 172 // 173 // Values that implement this interface can be passed to [*Case.Rets] to control 174 // the matching algorithm for return values. 175 type Matcher interface { 176 // Match reports whether a return value is considered a match. The argument 177 // is of type RetValue so that it cannot be implemented accidentally. 
178 Match(RetValue) bool 179 } 180 181 // Any is a Matcher that matches any value. 182 var Any Matcher = anyMatcher{} 183 184 type anyMatcher struct{} 185 186 func (anyMatcher) Match(RetValue) bool { return true } 187 188 func match(matchers, actual []any) bool { 189 for i, matcher := range matchers { 190 if !matchOne(matcher, actual[i]) { 191 return false 192 } 193 } 194 return true 195 } 196 197 func matchOne(m, a any) bool { 198 if m, ok := m.(Matcher); ok { 199 return m.Match(a) 200 } 201 return reflect.DeepEqual(m, a) 202 } 203 204 func sprintArgs(args ...any) string { 205 var b strings.Builder 206 for i, arg := range args { 207 if i > 0 { 208 b.WriteString(", ") 209 } 210 fmt.Fprint(&b, arg) 211 } 212 return b.String() 213 } 214 215 func call(fn any, args []any) []any { 216 argsReflect := make([]reflect.Value, len(args)) 217 for i, arg := range args { 218 if arg == nil { 219 // reflect.ValueOf(nil) returns a zero Value, but this is not what 220 // we want. Work around this by taking the ValueOf a pointer to nil 221 // and then get the Elem. 222 // TODO(xiaq): This is now always using a nil value with type 223 // interface{}. For more usability, inspect the type of fn to see 224 // which type of nil this argument should be. 225 var v any 226 argsReflect[i] = reflect.ValueOf(&v).Elem() 227 } else { 228 argsReflect[i] = reflect.ValueOf(arg) 229 } 230 } 231 retsReflect := reflect.ValueOf(fn).Call(argsReflect) 232 rets := make([]any, len(retsReflect)) 233 for i, retReflect := range retsReflect { 234 rets[i] = retReflect.Interface() 235 } 236 return rets 237 }