github.com/neatlab/neatio@v1.7.3-0.20220425043230-d903e92fcc75/internal/cmdtest/test_cmd.go (about)

     1  package cmdtest
     2  
     3  import (
     4  	"bufio"
     5  	"bytes"
     6  	"fmt"
     7  	"io"
     8  	"io/ioutil"
     9  	"os"
    10  	"os/exec"
    11  	"regexp"
    12  	"strings"
    13  	"sync"
    14  	"testing"
    15  	"text/template"
    16  	"time"
    17  
    18  	"github.com/docker/docker/pkg/reexec"
    19  )
    20  
    21  func NewTestCmd(t *testing.T, data interface{}) *TestCmd {
    22  	return &TestCmd{T: t, Data: data}
    23  }
    24  
    25  type TestCmd struct {
    26  	*testing.T
    27  
    28  	Func    template.FuncMap
    29  	Data    interface{}
    30  	Cleanup func()
    31  
    32  	cmd    *exec.Cmd
    33  	stdout *bufio.Reader
    34  	stdin  io.WriteCloser
    35  	stderr *testlogger
    36  }
    37  
    38  func (tt *TestCmd) Run(name string, args ...string) {
    39  	tt.stderr = &testlogger{t: tt.T}
    40  	tt.cmd = &exec.Cmd{
    41  		Path:   reexec.Self(),
    42  		Args:   append([]string{name}, args...),
    43  		Stderr: tt.stderr,
    44  	}
    45  	stdout, err := tt.cmd.StdoutPipe()
    46  	if err != nil {
    47  		tt.Fatal(err)
    48  	}
    49  	tt.stdout = bufio.NewReader(stdout)
    50  	if tt.stdin, err = tt.cmd.StdinPipe(); err != nil {
    51  		tt.Fatal(err)
    52  	}
    53  	if err := tt.cmd.Start(); err != nil {
    54  		tt.Fatal(err)
    55  	}
    56  }
    57  
    58  func (tt *TestCmd) InputLine(s string) string {
    59  	io.WriteString(tt.stdin, s+"\n")
    60  	return ""
    61  }
    62  
    63  func (tt *TestCmd) SetTemplateFunc(name string, fn interface{}) {
    64  	if tt.Func == nil {
    65  		tt.Func = make(map[string]interface{})
    66  	}
    67  	tt.Func[name] = fn
    68  }
    69  
    70  func (tt *TestCmd) Expect(tplsource string) {
    71  
    72  	tpl := template.Must(template.New("").Funcs(tt.Func).Parse(tplsource))
    73  	wantbuf := new(bytes.Buffer)
    74  	if err := tpl.Execute(wantbuf, tt.Data); err != nil {
    75  		panic(err)
    76  	}
    77  
    78  	want := bytes.TrimPrefix(wantbuf.Bytes(), []byte("\n"))
    79  	if err := tt.matchExactOutput(want); err != nil {
    80  		tt.Fatal(err)
    81  	}
    82  	tt.Logf("Matched stdout text:\n%s", want)
    83  }
    84  
    85  func (tt *TestCmd) matchExactOutput(want []byte) error {
    86  	buf := make([]byte, len(want))
    87  	n := 0
    88  	tt.withKillTimeout(func() { n, _ = io.ReadFull(tt.stdout, buf) })
    89  	buf = buf[:n]
    90  	if n < len(want) || !bytes.Equal(buf, want) {
    91  
    92  		buf = append(buf, make([]byte, tt.stdout.Buffered())...)
    93  		tt.stdout.Read(buf[n:])
    94  
    95  		for i := 0; i < n; i++ {
    96  			if want[i] != buf[i] {
    97  				return fmt.Errorf("Output mismatch at ā—Š:\n---------------- (stdout text)\n%sā—Š%s\n---------------- (expected text)\n%s",
    98  					buf[:i], buf[i:n], want)
    99  			}
   100  		}
   101  		if n < len(want) {
   102  			return fmt.Errorf("Not enough output, got until ā—Š:\n---------------- (stdout text)\n%s\n---------------- (expected text)\n%sā—Š%s",
   103  				buf, want[:n], want[n:])
   104  		}
   105  	}
   106  	return nil
   107  }
   108  
   109  func (tt *TestCmd) ExpectRegexp(regex string) (*regexp.Regexp, []string) {
   110  	regex = strings.TrimPrefix(regex, "\n")
   111  	var (
   112  		re      = regexp.MustCompile(regex)
   113  		rtee    = &runeTee{in: tt.stdout}
   114  		matches []int
   115  	)
   116  	tt.withKillTimeout(func() { matches = re.FindReaderSubmatchIndex(rtee) })
   117  	output := rtee.buf.Bytes()
   118  	if matches == nil {
   119  		tt.Fatalf("Output did not match:\n---------------- (stdout text)\n%s\n---------------- (regular expression)\n%s",
   120  			output, regex)
   121  		return re, nil
   122  	}
   123  	tt.Logf("Matched stdout text:\n%s", output)
   124  	var submatches []string
   125  	for i := 0; i < len(matches); i += 2 {
   126  		submatch := string(output[matches[i]:matches[i+1]])
   127  		submatches = append(submatches, submatch)
   128  	}
   129  	return re, submatches
   130  }
   131  
   132  func (tt *TestCmd) ExpectExit() {
   133  	var output []byte
   134  	tt.withKillTimeout(func() {
   135  		output, _ = ioutil.ReadAll(tt.stdout)
   136  	})
   137  	tt.WaitExit()
   138  	if tt.Cleanup != nil {
   139  		tt.Cleanup()
   140  	}
   141  	if len(output) > 0 {
   142  		tt.Errorf("Unmatched stdout text:\n%s", output)
   143  	}
   144  }
   145  
   146  func (tt *TestCmd) WaitExit() {
   147  	tt.cmd.Wait()
   148  }
   149  
   150  func (tt *TestCmd) Interrupt() {
   151  	tt.cmd.Process.Signal(os.Interrupt)
   152  }
   153  
   154  func (tt *TestCmd) StderrText() string {
   155  	tt.stderr.mu.Lock()
   156  	defer tt.stderr.mu.Unlock()
   157  	return tt.stderr.buf.String()
   158  }
   159  
   160  func (tt *TestCmd) CloseStdin() {
   161  	tt.stdin.Close()
   162  }
   163  
   164  func (tt *TestCmd) Kill() {
   165  	tt.cmd.Process.Kill()
   166  	if tt.Cleanup != nil {
   167  		tt.Cleanup()
   168  	}
   169  }
   170  
   171  func (tt *TestCmd) withKillTimeout(fn func()) {
   172  	timeout := time.AfterFunc(5*time.Second, func() {
   173  		tt.Log("killing the side process (timeout)")
   174  		tt.Kill()
   175  	})
   176  	defer timeout.Stop()
   177  	fn()
   178  }
   179  
   180  type testlogger struct {
   181  	t   *testing.T
   182  	mu  sync.Mutex
   183  	buf bytes.Buffer
   184  }
   185  
   186  func (tl *testlogger) Write(b []byte) (n int, err error) {
   187  	lines := bytes.Split(b, []byte("\n"))
   188  	for _, line := range lines {
   189  		if len(line) > 0 {
   190  			tl.t.Logf("(stderr) %s", line)
   191  		}
   192  	}
   193  	tl.mu.Lock()
   194  	tl.buf.Write(b)
   195  	tl.mu.Unlock()
   196  	return len(b), err
   197  }
   198  
   199  type runeTee struct {
   200  	in interface {
   201  		io.Reader
   202  		io.ByteReader
   203  		io.RuneReader
   204  	}
   205  	buf bytes.Buffer
   206  }
   207  
   208  func (rtee *runeTee) Read(b []byte) (n int, err error) {
   209  	n, err = rtee.in.Read(b)
   210  	rtee.buf.Write(b[:n])
   211  	return n, err
   212  }
   213  
   214  func (rtee *runeTee) ReadRune() (r rune, size int, err error) {
   215  	r, size, err = rtee.in.ReadRune()
   216  	if err == nil {
   217  		rtee.buf.WriteRune(r)
   218  	}
   219  	return r, size, err
   220  }
   221  
   222  func (rtee *runeTee) ReadByte() (b byte, err error) {
   223  	b, err = rtee.in.ReadByte()
   224  	if err == nil {
   225  		rtee.buf.WriteByte(b)
   226  	}
   227  	return b, err
   228  }