github.com/zuoyebang/bitalostable@v1.0.1-0.20240229032404-e3b99a834294/internal/datadriven/datadriven.go (about)

     1  // Copyright 2018 The Cockroach Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
    12  // implied. See the License for the specific language governing
    13  // permissions and limitations under the License.
    14  
    15  package datadriven // import "github.com/zuoyebang/bitalostable/internal/datadriven"
    16  
    17  import (
    18  	"bufio"
    19  	"flag"
    20  	"fmt"
    21  	"io"
    22  	"io/ioutil"
    23  	"os"
    24  	"path/filepath"
    25  	"strconv"
    26  	"strings"
    27  	"testing"
    28  
    29  	"github.com/cockroachdb/errors"
    30  )
    31  
    32  var (
    33  	rewriteTestFiles = flag.Bool(
    34  		"rewrite", false,
    35  		"ignore the expected results and rewrite the test files with the actual results from this "+
    36  			"run. Used to update tests when a change affects many cases; please verify the testfile "+
    37  			"diffs carefully!",
    38  	)
    39  )
    40  
    41  // RunTest invokes a data-driven test. The test cases are contained in a
    42  // separate test file and are dynamically loaded, parsed, and executed by this
    43  // testing framework. By convention, test files are typically located in a
    44  // sub-directory called "testdata". Each test file has the following format:
    45  //
    46  //	<command>[,<command>...] [arg | arg=val | arg=(val1, val2, ...)]...
    47  //	<input to the command>
    48  //	----
    49  //	<expected results>
    50  //
    51  // The command input can contain blank lines. However, by default, the expected
    52  // results cannot contain blank lines. This alternate syntax allows the use of
    53  // blank lines:
    54  //
    55  //	<command>[,<command>...] [arg | arg=val | arg=(val1, val2, ...)]...
    56  //	<input to the command>
    57  //	----
    58  //	----
    59  //	<expected results>
    60  //
    61  //	<more expected results>
    62  //	----
    63  //	----
    64  //
    65  // To execute data-driven tests, pass the path of the test file as well as a
    66  // function which can interpret and execute whatever commands are present in
    67  // the test file. The framework invokes the function, passing it information
    68  // about the test case in a TestData struct. The function then returns the
    69  // actual results of the case, which this function compares with the expected
    70  // results, and either succeeds or fails the test.
    71  func RunTest(t *testing.T, path string, f func(d *TestData) string) {
    72  	t.Helper()
    73  	file, err := os.OpenFile(path, os.O_RDWR, 0644 /* irrelevant */)
    74  	if err != nil {
    75  		t.Fatal(err)
    76  	}
    77  	defer func() {
    78  		_ = file.Close()
    79  	}()
    80  
    81  	runTestInternal(t, path, file, f, *rewriteTestFiles)
    82  }
    83  
    84  // RunTestFromString is a version of RunTest which takes the contents of a test
    85  // directly.
    86  func RunTestFromString(t *testing.T, input string, f func(d *TestData) string) {
    87  	t.Helper()
    88  	runTestInternal(t, "<string>" /* optionalPath */, strings.NewReader(input), f, *rewriteTestFiles)
    89  }
    90  
    91  func runTestInternal(
    92  	t *testing.T, sourceName string, reader io.Reader, f func(d *TestData) string, rewrite bool,
    93  ) {
    94  	t.Helper()
    95  
    96  	r := newTestDataReader(t, sourceName, reader, rewrite)
    97  	for r.Next(t) {
    98  		d := &r.data
    99  		actual := func() string {
   100  			defer func() {
   101  				if r := recover(); r != nil {
   102  					fmt.Printf("\npanic during %s:\n%s\n", d.Pos, d.Input)
   103  					panic(r)
   104  				}
   105  			}()
   106  			s := f(d)
   107  			if n := len(s); n > 0 && s[n-1] != '\n' {
   108  				s += "\n"
   109  			}
   110  			return s
   111  		}()
   112  
   113  		if r.rewrite != nil {
   114  			r.emit("----")
   115  			if hasBlankLine(actual) {
   116  				r.emit("----")
   117  				r.rewrite.WriteString(actual)
   118  				r.emit("----")
   119  				r.emit("----")
   120  			} else {
   121  				r.emit(actual)
   122  			}
   123  		} else if d.Expected != actual {
   124  			t.Fatalf("\n%s: %s\nexpected:\n%s\nfound:\n%s", d.Pos, d.Input, d.Expected, actual)
   125  		} else if testing.Verbose() {
   126  			input := d.Input
   127  			if input == "" {
   128  				input = "<no input to command>"
   129  			}
   130  			// TODO(tbg): it's awkward to reproduce the args, but it would be helpful.
   131  			fmt.Printf("\n%s:\n%s [%d args]\n%s\n----\n%s", d.Pos, d.Cmd, len(d.CmdArgs), input, actual)
   132  		}
   133  	}
   134  
   135  	if r.rewrite != nil {
   136  		data := r.rewrite.Bytes()
   137  		if l := len(data); l > 2 && data[l-1] == '\n' && data[l-2] == '\n' {
   138  			data = data[:l-1]
   139  		}
   140  		if dest, ok := reader.(*os.File); ok {
   141  			if _, err := dest.WriteAt(data, 0); err != nil {
   142  				t.Fatal(err)
   143  			}
   144  			if err := dest.Truncate(int64(len(data))); err != nil {
   145  				t.Fatal(err)
   146  			}
   147  			if err := dest.Sync(); err != nil {
   148  				t.Fatal(err)
   149  			}
   150  		} else {
   151  			t.Logf("input is not a file; rewritten output is:\n%s", data)
   152  		}
   153  	}
   154  }
   155  
   156  // Walk goes through all the files in a subdirectory, creating subtests to match
   157  // the file hierarchy; for each "leaf" file, the given function is called.
   158  //
   159  // This can be used in conjunction with RunTest. For example:
   160  //
   161  //	 datadriven.Walk(t, path, func (t *testing.T, path string) {
   162  //	   // initialize per-test state
   163  //	   datadriven.RunTest(t, path, func (d *datadriven.TestData) {
   164  //	    // ...
   165  //	   }
   166  //	 }
   167  //
   168  //	Files:
   169  //	  testdata/typing
   170  //	  testdata/logprops/scan
   171  //	  testdata/logprops/select
   172  //
   173  //	If path is "testdata/typing", the function is called once and no subtests
   174  //	care created.
   175  //
   176  //	If path is "testdata/logprops", the function is called two times, in
   177  //	separate subtests /scan, /select.
   178  //
   179  //	If path is "testdata", the function is called three times, in subtest
   180  //	hierarchy /typing, /logprops/scan, /logprops/select.
   181  func Walk(t *testing.T, path string, f func(t *testing.T, path string)) {
   182  	finfo, err := os.Stat(path)
   183  	if err != nil {
   184  		t.Fatal(err)
   185  	}
   186  	if !finfo.IsDir() {
   187  		f(t, path)
   188  		return
   189  	}
   190  	files, err := ioutil.ReadDir(path)
   191  	if err != nil {
   192  		t.Fatal(err)
   193  	}
   194  	for _, file := range files {
   195  		t.Run(file.Name(), func(t *testing.T) {
   196  			Walk(t, filepath.Join(path, file.Name()), f)
   197  		})
   198  	}
   199  }
   200  
   201  // TestData contains information about one data-driven test case that was
   202  // parsed from the test file.
   203  type TestData struct {
   204  	Pos string // reader and line number
   205  
   206  	// Cmd is the first string on the directive line (up to the first whitespace).
   207  	Cmd string
   208  
   209  	CmdArgs []CmdArg
   210  
   211  	Input    string
   212  	Expected string
   213  }
   214  
   215  // ScanArgs looks up the first CmdArg matching the given key and scans it into
   216  // the given destinations in order. If the arg does not exist, the number of
   217  // destinations does not match that of the arguments, or a destination can not
   218  // be populated from its matching value, a fatal error results.
   219  //
   220  // # For example, for a TestData originating from
   221  //
   222  // cmd arg1=50 arg2=yoruba arg3=(50, 50, 50)
   223  //
   224  // the following would be valid:
   225  //
   226  // var i1, i2, i3, i4 int
   227  // var s string
   228  // td.ScanArgs(t, "arg1", &i1)
   229  // td.ScanArgs(t, "arg2", &s)
   230  // td.ScanArgs(t, "arg3", &i2, &i3, &i4)
   231  func (td *TestData) ScanArgs(t *testing.T, key string, dests ...interface{}) {
   232  	t.Helper()
   233  	arg := td.findArg(key)
   234  	if arg == nil {
   235  		t.Fatalf("missing argument: %s", key)
   236  	}
   237  	err := arg.scan(dests...)
   238  	if err != nil {
   239  		t.Fatal(err)
   240  	}
   241  }
   242  
   243  // HasArg determines if `key` appears in CmdArgs.
   244  func (td *TestData) HasArg(key string) bool {
   245  	return td.findArg(key) != nil
   246  }
   247  
   248  func (td *TestData) findArg(key string) *CmdArg {
   249  	for i := range td.CmdArgs {
   250  		if td.CmdArgs[i].Key == key {
   251  			return &td.CmdArgs[i]
   252  		}
   253  	}
   254  	return nil
   255  }
   256  
   257  // CmdArg contains information about an argument on the directive line. An
   258  // argument is specified in one of the following forms:
   259  //   - argument
   260  //   - argument=value
   261  //   - argument=(values, ...)
   262  type CmdArg struct {
   263  	Key  string
   264  	Vals []string
   265  }
   266  
   267  func (arg CmdArg) String() string {
   268  	switch len(arg.Vals) {
   269  	case 0:
   270  		return arg.Key
   271  
   272  	case 1:
   273  		return fmt.Sprintf("%s=%s", arg.Key, arg.Vals[0])
   274  
   275  	default:
   276  		return fmt.Sprintf("%s=(%s)", arg.Key, strings.Join(arg.Vals, ", "))
   277  	}
   278  }
   279  
   280  func (arg CmdArg) scan(dests ...interface{}) error {
   281  	if len(dests) != len(arg.Vals) {
   282  		return errors.Errorf("%s: got %d destinations, but %d values", arg.Key, len(dests), len(arg.Vals))
   283  	}
   284  
   285  	for i := range dests {
   286  		val := arg.Vals[i]
   287  		switch dest := dests[i].(type) {
   288  		case *string:
   289  			*dest = val
   290  		case *int:
   291  			n, err := strconv.ParseInt(val, 10, 64)
   292  			if err != nil {
   293  				return err
   294  			}
   295  			*dest = int(n) // assume 64bit ints
   296  		case *uint64:
   297  			n, err := strconv.ParseUint(val, 10, 64)
   298  			if err != nil {
   299  				return err
   300  			}
   301  			*dest = n
   302  		case *bool:
   303  			b, err := strconv.ParseBool(val)
   304  			if err != nil {
   305  				return err
   306  			}
   307  			*dest = b
   308  		default:
   309  			return errors.Errorf("unsupported type %T for destination #%d (might be easy to add it)", dest, i+1)
   310  		}
   311  	}
   312  	return nil
   313  }
   314  
   315  // Fatalf wraps a fatal testing error with test file position information, so
   316  // that it's easy to locate the source of the error.
   317  func (td TestData) Fatalf(tb testing.TB, format string, args ...interface{}) {
   318  	tb.Helper()
   319  	tb.Fatalf("%s: %s", td.Pos, fmt.Sprintf(format, args...))
   320  }
   321  
   322  func hasBlankLine(s string) bool {
   323  	scanner := bufio.NewScanner(strings.NewReader(s))
   324  	for scanner.Scan() {
   325  		if strings.TrimSpace(scanner.Text()) == "" {
   326  			return true
   327  		}
   328  	}
   329  	return false
   330  }