cuelang.org/go@v0.10.1/internal/tdtest/tdtest.go (about)

     1  // Copyright 2023 CUE Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  // Package tdtest provides support for table-driven testing.
    16  //
// Features include automatic updating of test values, automatic error
// message generation, and singling out individual tests to run.
    19  //
    20  // Auto updating fields is only supported for fields that are scalar types:
    21  // string, bool, int*, and uint*. If the field is a string, the "actual" value
    22  // may be any Go value that can meaningfully be printed with fmt.Sprint.
    23  package tdtest
    24  
import (
	"fmt"
	"go/token"
	"path/filepath"
	"reflect"
	"runtime"
	"strings"
	"testing"
	"unicode"
)
    34  
    35  // TODO:
    36  // - make this a public package at some point.
    37  // - add tests. Maybe adding Examples is sufficient.
    38  // - use text-based modification, instead of astutil. The latter is too brittle.
    39  // - allow updating position-based, instead of named, fields.
    40  // - implement skip, maybe match
    41  // - make name field explicit, i.e. Name("name"), field tag, or tdtest.Name type.
    42  // - allow "skip" field. Again either SkipName("skip"), tag, or Skip type.
    43  // - allow for tdtest:"noupdate" field tag.
    44  // - Option: allow ignore field that lists a set of fields to not be tested
    45  //   for that particular test case: ignore: tdtest.Ignore("want1", "want2")
    46  //
    47  
    48  // UpdateTests defines whether tests should be updated by default.
    49  // This can be overridden on an individual basis using T.Update.
    50  var UpdateTests = false
    51  
    52  // set is the set of tests to run.
    53  type set[TC any] struct {
    54  	t *testing.T
    55  
    56  	table []TC
    57  
    58  	updateEnabled bool
    59  	info          *info
    60  }
    61  
    62  // Run runs the given function for each (selected) element in the table.
    63  // TC must be a struct type. If that has a string field named "name",
    64  // that value will be used to name the associated subtest.
    65  func Run[TC any](t *testing.T, table []TC, fn func(t *T, tc *TC)) {
    66  	s := &set[TC]{
    67  		t:             t,
    68  		table:         table,
    69  		updateEnabled: UpdateTests,
    70  	}
    71  	for i := range s.table {
    72  		name := fmt.Sprint(i)
    73  
    74  		x := reflect.ValueOf(s.table[i]).FieldByName("name")
    75  		if x.Kind() == reflect.String {
    76  			name += "/" + x.String()
    77  		}
    78  
    79  		s.t.Run(name, func(t *testing.T) {
    80  			tt := &T{
    81  				T:             t,
    82  				iter:          i,
    83  				infoSrc:       s,
    84  				updateEnabled: s.updateEnabled,
    85  			}
    86  			fn(tt, &s.table[i])
    87  		})
    88  	}
    89  	if s.info != nil && s.info.needsUpdate {
    90  		s.update()
    91  	}
    92  }
    93  
    94  // T is a single test case representing an element in a table.
    95  // It embeds *testing.T, so all functions of testing.T are available.
    96  type T struct {
    97  	*testing.T
    98  
    99  	infoSrc interface{ getInfo(file string) *info }
   100  	iter    int // position in the table of the current subtest.
   101  
   102  	updateEnabled bool
   103  }
   104  
   105  func (t *T) info(file string) *info {
   106  	return t.infoSrc.getInfo(file)
   107  }
   108  
   109  func (t *T) getCallInfo() (*info, *callInfo) {
   110  	_, file, line, ok := runtime.Caller(2)
   111  	if !ok {
   112  		t.Fatalf("could not update file for test %s", t.Name())
   113  	}
   114  	// Note: it seems that sometimes the file returned by Caller
   115  	// might not be in canonical format (under Windows, it can contain
   116  	// forward slashes), so clean it.
   117  	file = filepath.Clean(file)
   118  	info := t.info(file)
   119  	return info, info.calls[token.Position{Filename: file, Line: line}]
   120  }
   121  
   122  // Equal compares two fields.
   123  //
   124  // For auto updating to work, field must reference a field in the test case
   125  // directly.
   126  func (t *T) Equal(actual, field any, msgAndArgs ...any) {
   127  	t.Helper()
   128  
   129  	switch {
   130  	case field == actual:
   131  	case t.updateEnabled:
   132  		info, ci := t.getCallInfo()
   133  		t.updateField(info, ci, actual)
   134  	case len(msgAndArgs) == 0:
   135  		_, ci := t.getCallInfo()
   136  		t.Errorf("unexpected value for field %s:\ngot:  %v;\nwant: %v", ci.fieldName, actual, field)
   137  	default:
   138  		format := msgAndArgs[0].(string) + ":\ngot:  %v;\nwant: %v"
   139  		args := append(msgAndArgs[1:], actual, field)
   140  		t.Errorf(format, args...)
   141  	}
   142  }
   143  
   144  // Update specifies whether to update the Go structs in case of discrepancies.
   145  // It overrides the default setting.
   146  func (t *T) Update(enable bool) {
   147  	t.updateEnabled = enable
   148  }
   149  
   150  // Select species which tests to run. The test may be an int, in which case
   151  // it selects the table entry to run, or a string, which is matched against
   152  // the last path of the test. An empty list runs all tests.
   153  func (t *T) Select(tests ...any) {
   154  	if len(tests) == 0 {
   155  		return
   156  	}
   157  
   158  	t.Helper()
   159  
   160  	name := t.Name()
   161  	parts := strings.Split(name, "/")
   162  
   163  	for _, n := range tests {
   164  		switch n := n.(type) {
   165  		case int:
   166  			if n == t.iter {
   167  				return
   168  			}
   169  		case string:
   170  			n = strings.ReplaceAll(n, " ", "_")
   171  			if n == parts[len(parts)-1] {
   172  				return
   173  			}
   174  		default:
   175  			panic("unexpected type passed to Select")
   176  		}
   177  	}
   178  	t.Skip("not selected")
   179  }