github.com/elves/elvish@v0.15.0/pkg/eval/evaltest/evaltest.go (about)

// Framework for testing Elvish scripts. This file does not have a _test.go
// suffix so that it can be used from other packages that also want to test the
// modules they implement (e.g. edit: and re:).
     4  //
     5  // The entry point for the framework is the Test function, which accepts a
     6  // *testing.T and a variadic number of test cases. Test cases are constructed
     7  // using the That function followed by methods that add constraints on the test
     8  // case. Overall, a test looks like:
     9  //
    10  //     Test(t,
    11  //         That("put x").Puts("x"),
    12  //         That("echo x").Prints("x\n"))
    13  //
    14  // If some setup is needed, use the TestWithSetup function instead.
    15  
    16  package evaltest
    17  
    18  import (
    19  	"bytes"
    20  	"os"
    21  	"reflect"
    22  	"strings"
    23  	"testing"
    24  
    25  	"github.com/elves/elvish/pkg/eval"
    26  	"github.com/elves/elvish/pkg/eval/vals"
    27  	"github.com/elves/elvish/pkg/parse"
    28  	"github.com/elves/elvish/pkg/testutil"
    29  )
    30  
    31  // TestCase is a test case for Test.
    32  type TestCase struct {
    33  	codes []string
    34  	want  Result
    35  }
    36  
    37  type Result struct {
    38  	ValueOut  []interface{}
    39  	BytesOut  []byte
    40  	StderrOut []byte
    41  
    42  	CompilationError error
    43  	Exception        error
    44  }
    45  
    46  // The following functions and methods are used to build Test structs. They are
    47  // supposed to read like English, so a test that "put x" should put "x" reads:
    48  //
    49  // That("put x").Puts("x")
    50  
    51  // That returns a new Test with the specified source code. Multiple arguments
    52  // are joined with newlines. To specify multiple pieces of code that are
    53  // executed separately, use the Then method to append code pieces.
    54  func That(lines ...string) TestCase {
    55  	return TestCase{codes: []string{strings.Join(lines, "\n")}}
    56  }
    57  
    58  // Then returns a new Test that executes the given code in addition. Multiple
    59  // arguments are joined with newlines.
    60  func (t TestCase) Then(lines ...string) TestCase {
    61  	t.codes = append(t.codes, strings.Join(lines, "\n"))
    62  	return t
    63  }
    64  
    65  // DoesNothing returns t unchanged. It is used to mark that a piece of code
    66  // should simply does nothing. In particular, it shouldn't have any output and
    67  // does not error.
    68  func (t TestCase) DoesNothing() TestCase {
    69  	return t
    70  }
    71  
    72  // Puts returns an altered TestCase that requires the source code to produce the
    73  // specified values in the value channel when evaluated.
    74  func (t TestCase) Puts(vs ...interface{}) TestCase {
    75  	t.want.ValueOut = vs
    76  	return t
    77  }
    78  
    79  // Prints returns an altered TestCase that requires the source code to produce
    80  // the specified output in the byte pipe when evaluated.
    81  func (t TestCase) Prints(s string) TestCase {
    82  	t.want.BytesOut = []byte(s)
    83  	return t
    84  }
    85  
    86  // PrintsStderrWith returns an altered TestCase that requires the stderr
    87  // output to contain the given text.
    88  func (t TestCase) PrintsStderrWith(s string) TestCase {
    89  	t.want.StderrOut = []byte(s)
    90  	return t
    91  }
    92  
    93  // Throws returns an altered TestCase that requires the source code to throw an
    94  // exception with the given reason. The reason supports special matcher values
    95  // constructed by functions like ErrorWithMessage.
    96  //
    97  // If at least one stacktrace string is given, the exception must also have a
    98  // stacktrace matching the given source fragments, frame by frame (innermost
    99  // frame first). If no stacktrace string is given, the stack trace of the
   100  // exception is not checked.
   101  func (t TestCase) Throws(reason error, stacks ...string) TestCase {
   102  	t.want.Exception = exc{reason, stacks}
   103  	return t
   104  }
   105  
   106  // DoesNotCompile returns an altered TestCase that requires the source code to
   107  // fail compilation.
   108  func (t TestCase) DoesNotCompile() TestCase {
   109  	t.want.CompilationError = anyError{}
   110  	return t
   111  }
   112  
   113  // Test runs test cases. For each test case, a new Evaler is created with
   114  // NewEvaler.
   115  func Test(t *testing.T, tests ...TestCase) {
   116  	t.Helper()
   117  	TestWithSetup(t, func(*eval.Evaler) {}, tests...)
   118  }
   119  
   120  // TestWithSetup runs test cases. For each test case, a new Evaler is created
   121  // with NewEvaler and passed to the setup function.
   122  func TestWithSetup(t *testing.T, setup func(*eval.Evaler), tests ...TestCase) {
   123  	t.Helper()
   124  	for _, tt := range tests {
   125  		t.Run(strings.Join(tt.codes, "\n"), func(t *testing.T) {
   126  			t.Helper()
   127  			ev := eval.NewEvaler()
   128  			setup(ev)
   129  
   130  			r := evalAndCollect(t, ev, tt.codes)
   131  
   132  			if !matchOut(tt.want.ValueOut, r.ValueOut) {
   133  				t.Errorf("got value out %v, want %v",
   134  					reprs(r.ValueOut), reprs(tt.want.ValueOut))
   135  			}
   136  			if !bytes.Equal(tt.want.BytesOut, r.BytesOut) {
   137  				t.Errorf("got bytes out %q, want %q", r.BytesOut, tt.want.BytesOut)
   138  			}
   139  			if !bytes.Contains(r.StderrOut, tt.want.StderrOut) {
   140  				t.Errorf("got stderr out %q, want %q", r.StderrOut, tt.want.StderrOut)
   141  			}
   142  			if !matchErr(tt.want.CompilationError, r.CompilationError) {
   143  				t.Errorf("got compilation error %v, want %v",
   144  					r.CompilationError, tt.want.CompilationError)
   145  			}
   146  			if !matchErr(tt.want.Exception, r.Exception) {
   147  				t.Errorf("unexpected exception")
   148  				t.Logf("got: %v", r.Exception)
   149  				if exc, ok := r.Exception.(eval.Exception); ok {
   150  					t.Logf("stack trace: %#v", getStackTexts(exc.StackTrace()))
   151  				}
   152  				t.Errorf("want: %v", tt.want.Exception)
   153  			}
   154  		})
   155  	}
   156  }
   157  
   158  func evalAndCollect(t *testing.T, ev *eval.Evaler, texts []string) Result {
   159  	var r Result
   160  
   161  	port1, collect1 := capturePort()
   162  	port2, collect2 := capturePort()
   163  	ports := []*eval.Port{eval.DummyInputPort, port1, port2}
   164  
   165  	for _, text := range texts {
   166  		err := ev.Eval(parse.Source{Name: "[test]", Code: text},
   167  			eval.EvalCfg{Ports: ports, Interrupt: eval.ListenInterrupts})
   168  
   169  		if parse.GetError(err) != nil {
   170  			t.Fatalf("Parse(%q) error: %s", text, err)
   171  		} else if eval.GetCompilationError(err) != nil {
   172  			// NOTE: If multiple code pieces have compilation errors, only the
   173  			// last one compilation error is saved.
   174  			r.CompilationError = err
   175  		} else if err != nil {
   176  			// NOTE: If multiple code pieces throw exceptions, only the last one
   177  			// is saved.
   178  			r.Exception = err
   179  		}
   180  	}
   181  
   182  	r.ValueOut, r.BytesOut = collect1()
   183  	_, r.StderrOut = collect2()
   184  	return r
   185  }
   186  
   187  // Like eval.CapturePort, but captures values and bytes separately. Also panics
   188  // if it cannot create a pipe.
   189  func capturePort() (*eval.Port, func() ([]interface{}, []byte)) {
   190  	var values []interface{}
   191  	var bytes []byte
   192  	port, done, err := eval.PipePort(
   193  		func(ch <-chan interface{}) {
   194  			for v := range ch {
   195  				values = append(values, v)
   196  			}
   197  		},
   198  		func(r *os.File) {
   199  			bytes = testutil.MustReadAllAndClose(r)
   200  		})
   201  	if err != nil {
   202  		panic(err)
   203  	}
   204  	return port, func() ([]interface{}, []byte) {
   205  		done()
   206  		return values, bytes
   207  	}
   208  }
   209  
   210  func matchOut(want, got []interface{}) bool {
   211  	if len(got) == 0 && len(want) == 0 {
   212  		return true
   213  	}
   214  	if len(got) != len(want) {
   215  		return false
   216  	}
   217  	for i := range got {
   218  		switch g := got[i].(type) {
   219  		case float64:
   220  			// Special-case float64 to correctly handle NaN and support
   221  			// approximate comparison.
   222  			switch w := want[i].(type) {
   223  			case float64:
   224  				if !matchFloat64(g, w, 0) {
   225  					return false
   226  				}
   227  			case Approximately:
   228  				if !matchFloat64(g, w.F, ApproximatelyThreshold) {
   229  					return false
   230  				}
   231  			default:
   232  				return false
   233  			}
   234  		default:
   235  			if !vals.Equal(got[i], want[i]) {
   236  				return false
   237  			}
   238  		}
   239  	}
   240  	return true
   241  }
   242  
   243  func reprs(values []interface{}) []string {
   244  	s := make([]string, len(values))
   245  	for i, v := range values {
   246  		s[i] = vals.Repr(v, vals.NoPretty)
   247  	}
   248  	return s
   249  }
   250  
   251  func matchErr(want, got error) bool {
   252  	if want == nil {
   253  		return got == nil
   254  	}
   255  	if matcher, ok := want.(errorMatcher); ok {
   256  		return matcher.matchError(got)
   257  	}
   258  	return reflect.DeepEqual(want, got)
   259  }