github.com/markusbkk/elvish@v0.0.0-20231204143114-91dc52438621/pkg/eval/evaltest/evaltest.go (about)

     1  // Package evaltest provides a framework for testing Elvish script.
     2  //
     3  // The entry point for the framework is the Test function, which accepts a
     4  // *testing.T and any number of test cases.
     5  //
     6  // Test cases are constructed using the That function, followed by method calls
     7  // that add additional information to it.
     8  //
     9  // Example:
    10  //
    11  //     Test(t,
    12  //         That("put x").Puts("x"),
    13  //         That("echo x").Prints("x\n"))
    14  //
    15  // If some setup is needed, use the TestWithSetup function instead.
    16  package evaltest
    17  
    18  import (
    19  	"bytes"
    20  	"os"
    21  	"reflect"
    22  	"strings"
    23  	"testing"
    24  
    25  	"github.com/markusbkk/elvish/pkg/eval"
    26  	"github.com/markusbkk/elvish/pkg/eval/vals"
    27  	"github.com/markusbkk/elvish/pkg/parse"
    28  	"github.com/markusbkk/elvish/pkg/testutil"
    29  )
    30  
    31  // Case is a test case that can be used in Test.
    32  type Case struct {
    33  	codes  []string
    34  	setup  func(ev *eval.Evaler)
    35  	verify func(t *testing.T)
    36  	want   result
    37  }
    38  
    39  type result struct {
    40  	ValueOut  []interface{}
    41  	BytesOut  []byte
    42  	StderrOut []byte
    43  
    44  	CompilationError error
    45  	Exception        error
    46  }
    47  
    48  // That returns a new Case with the specified source code. Multiple arguments
    49  // are joined with newlines. To specify multiple pieces of code that are
    50  // executed separately, use the Then method to append code pieces.
    51  //
    52  // When combined with subsequent method calls, a test case reads like English.
    53  // For example, a test for the fact that "put x" puts "x" reads:
    54  //
    55  //     That("put x").Puts("x")
    56  func That(lines ...string) Case {
    57  	return Case{codes: []string{strings.Join(lines, "\n")}}
    58  }
    59  
    60  // Then returns a new Case that executes the given code in addition. Multiple
    61  // arguments are joined with newlines.
    62  func (c Case) Then(lines ...string) Case {
    63  	c.codes = append(c.codes, strings.Join(lines, "\n"))
    64  	return c
    65  }
    66  
    67  // Then returns a new Case with the given setup function executed on the Evaler
    68  // before the code is executed.
    69  func (c Case) WithSetup(f func(*eval.Evaler)) Case {
    70  	c.setup = f
    71  	return c
    72  }
    73  
    74  // DoesNothing returns t unchanged. It is useful to mark tests that don't have
    75  // any side effects, for example:
    76  //
    77  //     That("nop").DoesNothing()
    78  func (c Case) DoesNothing() Case {
    79  	return c
    80  }
    81  
    82  // Puts returns an altered Case that runs an additional verification function.
    83  func (c Case) Passes(f func(t *testing.T)) Case {
    84  	c.verify = f
    85  	return c
    86  }
    87  
    88  // Puts returns an altered Case that requires the source code to produce the
    89  // specified values in the value channel when evaluated.
    90  func (c Case) Puts(vs ...interface{}) Case {
    91  	c.want.ValueOut = vs
    92  	return c
    93  }
    94  
    95  // Prints returns an altered Case that requires the source code to produce the
    96  // specified output in the byte pipe when evaluated.
    97  func (c Case) Prints(s string) Case {
    98  	c.want.BytesOut = []byte(s)
    99  	return c
   100  }
   101  
   102  // PrintsStderrWith returns an altered Case that requires the stderr output to
   103  // contain the given text.
   104  func (c Case) PrintsStderrWith(s string) Case {
   105  	c.want.StderrOut = []byte(s)
   106  	return c
   107  }
   108  
   109  // Throws returns an altered Case that requires the source code to throw an
   110  // exception with the given reason. The reason supports special matcher values
   111  // constructed by functions like ErrorWithMessage.
   112  //
   113  // If at least one stacktrace string is given, the exception must also have a
   114  // stacktrace matching the given source fragments, frame by frame (innermost
   115  // frame first). If no stacktrace string is given, the stack trace of the
   116  // exception is not checked.
   117  func (c Case) Throws(reason error, stacks ...string) Case {
   118  	c.want.Exception = exc{reason, stacks}
   119  	return c
   120  }
   121  
   122  // DoesNotCompile returns an altered Case that requires the source code to fail
   123  // compilation.
   124  func (c Case) DoesNotCompile() Case {
   125  	c.want.CompilationError = anyError{}
   126  	return c
   127  }
   128  
   129  // Test runs test cases. For each test case, a new Evaler is created with
   130  // NewEvaler.
   131  func Test(t *testing.T, tests ...Case) {
   132  	t.Helper()
   133  	TestWithSetup(t, func(*eval.Evaler) {}, tests...)
   134  }
   135  
   136  // TestWithSetup runs test cases. For each test case, a new Evaler is created
   137  // with NewEvaler and passed to the setup function.
   138  func TestWithSetup(t *testing.T, setup func(*eval.Evaler), tests ...Case) {
   139  	t.Helper()
   140  	for _, tt := range tests {
   141  		t.Run(strings.Join(tt.codes, "\n"), func(t *testing.T) {
   142  			t.Helper()
   143  			ev := eval.NewEvaler()
   144  			setup(ev)
   145  			if tt.setup != nil {
   146  				tt.setup(ev)
   147  			}
   148  
   149  			r := evalAndCollect(t, ev, tt.codes)
   150  
   151  			if tt.verify != nil {
   152  				tt.verify(t)
   153  			}
   154  			if !matchOut(tt.want.ValueOut, r.ValueOut) {
   155  				t.Errorf("got value out %v, want %v",
   156  					reprs(r.ValueOut), reprs(tt.want.ValueOut))
   157  			}
   158  			if !bytes.Equal(tt.want.BytesOut, r.BytesOut) {
   159  				t.Errorf("got bytes out %q, want %q", r.BytesOut, tt.want.BytesOut)
   160  			}
   161  			if !bytes.Contains(r.StderrOut, tt.want.StderrOut) {
   162  				t.Errorf("got stderr out %q, want %q", r.StderrOut, tt.want.StderrOut)
   163  			}
   164  			if !matchErr(tt.want.CompilationError, r.CompilationError) {
   165  				t.Errorf("got compilation error %v, want %v",
   166  					r.CompilationError, tt.want.CompilationError)
   167  			}
   168  			if !matchErr(tt.want.Exception, r.Exception) {
   169  				t.Errorf("unexpected exception")
   170  				if exc, ok := r.Exception.(eval.Exception); ok {
   171  					// For an eval.Exception report the type of the underlying error.
   172  					t.Logf("got: %T: %v", exc.Reason(), exc)
   173  					t.Logf("stack trace: %#v", getStackTexts(exc.StackTrace()))
   174  				} else {
   175  					t.Logf("got: %T: %v", r.Exception, r.Exception)
   176  				}
   177  				t.Errorf("want: %v", tt.want.Exception)
   178  			}
   179  		})
   180  	}
   181  }
   182  
   183  func evalAndCollect(t *testing.T, ev *eval.Evaler, texts []string) result {
   184  	var r result
   185  
   186  	port1, collect1 := capturePort()
   187  	port2, collect2 := capturePort()
   188  	ports := []*eval.Port{eval.DummyInputPort, port1, port2}
   189  
   190  	for _, text := range texts {
   191  		err := ev.Eval(parse.Source{Name: "[test]", Code: text},
   192  			eval.EvalCfg{Ports: ports, Interrupt: eval.ListenInterrupts})
   193  
   194  		if parse.GetError(err) != nil {
   195  			t.Fatalf("Parse(%q) error: %s", text, err)
   196  		} else if eval.GetCompilationError(err) != nil {
   197  			// NOTE: If multiple code pieces have compilation errors, only the
   198  			// last one compilation error is saved.
   199  			r.CompilationError = err
   200  		} else if err != nil {
   201  			// NOTE: If multiple code pieces throw exceptions, only the last one
   202  			// is saved.
   203  			r.Exception = err
   204  		}
   205  	}
   206  
   207  	r.ValueOut, r.BytesOut = collect1()
   208  	_, r.StderrOut = collect2()
   209  	return r
   210  }
   211  
   212  // Like eval.CapturePort, but captures values and bytes separately. Also panics
   213  // if it cannot create a pipe.
   214  func capturePort() (*eval.Port, func() ([]interface{}, []byte)) {
   215  	var values []interface{}
   216  	var bytes []byte
   217  	port, done, err := eval.PipePort(
   218  		func(ch <-chan interface{}) {
   219  			for v := range ch {
   220  				values = append(values, v)
   221  			}
   222  		},
   223  		func(r *os.File) {
   224  			bytes = testutil.MustReadAllAndClose(r)
   225  		})
   226  	if err != nil {
   227  		panic(err)
   228  	}
   229  	return port, func() ([]interface{}, []byte) {
   230  		done()
   231  		return values, bytes
   232  	}
   233  }
   234  
   235  func matchOut(want, got []interface{}) bool {
   236  	if len(got) != len(want) {
   237  		return false
   238  	}
   239  	for i := range got {
   240  		if !match(got[i], want[i]) {
   241  			return false
   242  		}
   243  	}
   244  	return true
   245  }
   246  
   247  func match(got, want interface{}) bool {
   248  	switch got := got.(type) {
   249  	case float64:
   250  		// Special-case float64 to correctly handle NaN and support
   251  		// approximate comparison.
   252  		switch want := want.(type) {
   253  		case float64:
   254  			return matchFloat64(got, want, 0)
   255  		case Approximately:
   256  			return matchFloat64(got, want.F, ApproximatelyThreshold)
   257  		}
   258  	case string:
   259  		switch want := want.(type) {
   260  		case MatchingRegexp:
   261  			return matchRegexp(want.Pattern, got)
   262  		}
   263  	}
   264  	return vals.Equal(got, want)
   265  }
   266  
   267  func reprs(values []interface{}) []string {
   268  	s := make([]string, len(values))
   269  	for i, v := range values {
   270  		s[i] = vals.ReprPlain(v)
   271  	}
   272  	return s
   273  }
   274  
   275  func matchErr(want, got error) bool {
   276  	if want == nil {
   277  		return got == nil
   278  	}
   279  	if matcher, ok := want.(errorMatcher); ok {
   280  		return matcher.matchError(got)
   281  	}
   282  	return reflect.DeepEqual(want, got)
   283  }