github.com/oweisse/u-root@v0.0.0-20181109060735-d005ad25fef1/cmds/elvish/eval/testutils.go (about)

// Common testing utilities. This file does not have a _test.go suffix so that
// it can be used from other packages that also want to test the modules they
// implement (e.g. edit: and re:).
     4  
     5  package eval
     6  
     7  import (
     8  	"bytes"
     9  	"errors"
    10  	"fmt"
    11  	"os"
    12  	"reflect"
    13  
    14  	"github.com/u-root/u-root/cmds/elvish/eval/vals"
    15  	"github.com/u-root/u-root/cmds/elvish/parse"
    16  )
    17  
// Test is a test case for TestEval, run by RunTests. It pairs a piece of
// source code with the results its evaluation is expected to produce. Build
// one with That and the expectation methods defined on Test.
type Test struct {
	// Source code to evaluate.
	text string
	// Expected results of evaluating text.
	want
}

// want holds the expected results of evaluating a Test's source code. The zero
// value expects no values, no byte output and no error.
type want struct {
	// Values expected on the value channel; compared element-wise with
	// vals.Equal by matchOut.
	out []interface{}
	// Bytes expected on the byte output; compared with bytes.Equal.
	bytesOut []byte
	// Expected error: nil means no error, errAny means any non-nil error;
	// anything else is interpreted by matchErr.
	err error
}
    29  
// errAny is a special value for want.err indicating that any error is OK, as
// long as it is not nil. matchErr recognizes it by identity (want == errAny),
// so its message is never compared against anything.
var errAny = errors.New("any error")
    33  
    34  // The following functions and methods are used to build Test structs. They are
    35  // supposed to read like English, so a test that "put x" should put "x" reads:
    36  //
    37  // That("put x").Puts("x")
    38  
    39  // That returns a new Test with the specified source code.
    40  func That(text string) Test {
    41  	return Test{text: text}
    42  }
    43  
    44  // DoesNothing returns t unchanged. It is used to mark that a piece of code
    45  // should simply does nothing. In particular, it shouldn't have any output and
    46  // does not error.
    47  func (t Test) DoesNothing() Test {
    48  	return t
    49  }
    50  
    51  // Puts returns an altered Test that requires the source code to produce the
    52  // specified values in the value channel when evaluated.
    53  func (t Test) Puts(vs ...interface{}) Test {
    54  	t.want.out = vs
    55  	return t
    56  }
    57  
    58  // Puts returns an altered Test that requires the source code to produce the
    59  // specified strings in the value channel when evaluated.
    60  func (t Test) PutsStrings(ss []string) Test {
    61  	t.want.out = make([]interface{}, len(ss))
    62  	for i, s := range ss {
    63  		t.want.out[i] = s
    64  	}
    65  	return t
    66  }
    67  
    68  // Prints returns an altered test that requires the source code to produce
    69  // the specified output in the byte pipe when evaluated.
    70  func (t Test) Prints(s string) Test {
    71  	t.want.bytesOut = []byte(s)
    72  	return t
    73  }
    74  
    75  // ErrorsWith returns an altered Test that requires the source code to result in
    76  // the specified error when evaluted.
    77  func (t Test) ErrorsWith(err error) Test {
    78  	t.want.err = err
    79  	return t
    80  }
    81  
    82  // Errors returns an altered Test that requires the source code to result in any
    83  // error when evaluated.
    84  func (t Test) Errors() Test {
    85  	return t.ErrorsWith(errAny)
    86  }
    87  
    88  // RunTests runs test cases. For each test case, a new Evaler is made by calling
    89  // makeEvaler.
    90  func RunTests(evalTests []Test, makeEvaler func() *Evaler) error {
    91  	for _, tt := range evalTests {
    92  		// fmt.Printf("eval %q\n", tt.text)
    93  
    94  		ev := makeEvaler()
    95  		defer ev.Close()
    96  		out, bytesOut, err := evalAndCollect(ev, []string{tt.text}, len(tt.want.out))
    97  
    98  		first := true
    99  		errorf := func(format string, args ...interface{}) error {
   100  			if first {
   101  				first = false
   102  				return fmt.Errorf("eval(%q) fails:", tt.text)
   103  			}
   104  			return fmt.Errorf("  "+format, args...)
   105  		}
   106  
   107  		if !matchOut(tt.want.out, out) {
   108  			if err := errorf("got out=%v, want %v", out, tt.want.out); err != nil {
   109  				return err
   110  			}
   111  		}
   112  		if !bytes.Equal(tt.want.bytesOut, bytesOut) {
   113  			if err := errorf("got bytesOut=%q, want %q", bytesOut, tt.want.bytesOut); err != nil {
   114  				return err
   115  			}
   116  		}
   117  		if !matchErr(tt.want.err, err) {
   118  			if err := errorf("got err=%v, want %v", err, tt.want.err); err != nil {
   119  				return err
   120  			}
   121  		}
   122  	}
   123  	return nil
   124  }
   125  
   126  func evalAndCollect(ev *Evaler, texts []string, chsize int) ([]interface{}, []byte, error) {
   127  	// Collect byte output
   128  	bytesOut := []byte{}
   129  	pr, pw, _ := os.Pipe()
   130  	bytesDone := make(chan struct{})
   131  	go func() {
   132  		for {
   133  			var buf [64]byte
   134  			nr, err := pr.Read(buf[:])
   135  			bytesOut = append(bytesOut, buf[:nr]...)
   136  			if err != nil {
   137  				break
   138  			}
   139  		}
   140  		close(bytesDone)
   141  	}()
   142  
   143  	// Channel output
   144  	outs := []interface{}{}
   145  
   146  	// Eval error. Only that of the last text is saved.
   147  	var ex error
   148  
   149  	for i, text := range texts {
   150  		name := fmt.Sprintf("test%d.elv", i)
   151  		src := NewScriptSource(name, name, text)
   152  
   153  		op, err := mustParseAndCompile(ev, src)
   154  		if err != nil {
   155  			return nil, nil, err
   156  		}
   157  
   158  		outCh := make(chan interface{}, chsize)
   159  		outDone := make(chan struct{})
   160  		go func() {
   161  			for v := range outCh {
   162  				outs = append(outs, v)
   163  			}
   164  			close(outDone)
   165  		}()
   166  
   167  		ports := []*Port{
   168  			{File: os.Stdin, Chan: ClosedChan},
   169  			{File: pw, Chan: outCh},
   170  			{File: os.Stderr, Chan: BlackholeChan},
   171  		}
   172  
   173  		ex = ev.eval(op, ports, src)
   174  		close(outCh)
   175  		<-outDone
   176  	}
   177  
   178  	pw.Close()
   179  	<-bytesDone
   180  	pr.Close()
   181  
   182  	return outs, bytesOut, ex
   183  }
   184  
   185  func mustParseAndCompile(ev *Evaler, src *Source) (Op, error) {
   186  	n, err := parse.Parse(src.name, src.code)
   187  	if err != nil {
   188  		return Op{}, fmt.Errorf("Parse(%q) error: %s", src.code, err)
   189  	}
   190  	op, err := ev.Compile(n, src)
   191  	if err != nil {
   192  		return Op{}, fmt.Errorf("Compile(Parse(%q)) error: %s", src.code, err)
   193  	}
   194  	return op, nil
   195  }
   196  
   197  func matchOut(want, got []interface{}) bool {
   198  	if len(got) == 0 && len(want) == 0 {
   199  		return true
   200  	}
   201  	if len(got) != len(want) {
   202  		return false
   203  	}
   204  	for i := range got {
   205  		if !vals.Equal(got[i], want[i]) {
   206  			return false
   207  		}
   208  	}
   209  	return true
   210  }
   211  
   212  func matchErr(want, got error) bool {
   213  	if got == nil {
   214  		return want == nil
   215  	}
   216  	return want == errAny || reflect.DeepEqual(got.(*Exception).Cause, want)
   217  }
   218  
   219  // compareValues compares two slices, using equals for each element.
   220  func compareSlice(wantValues, gotValues []interface{}) error {
   221  	if len(wantValues) != len(gotValues) {
   222  		return fmt.Errorf("want %d values, got %d",
   223  			len(wantValues), len(gotValues))
   224  	}
   225  	for i, want := range wantValues {
   226  		if !vals.Equal(want, gotValues[i]) {
   227  			return fmt.Errorf("want [%d] = %s, got %s", i, want, gotValues[i])
   228  		}
   229  	}
   230  	return nil
   231  }