gvisor.dev/gvisor@v0.0.0-20240520182842-f9d4d51c7e0f/pkg/test/testutil/sh.go (about)

     1  // Copyright 2020 The gVisor Authors.
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package testutil
    16  
    17  import (
    18  	"bytes"
    19  	"context"
    20  	"fmt"
    21  	"io"
    22  	"os"
    23  	"os/exec"
    24  	"strings"
    25  	"time"
    26  
    27  	"github.com/kr/pty"
    28  	"golang.org/x/sys/unix"
    29  )
    30  
// Prompt is the string used as the shell prompt (it is exported to the shell
// as PS1 by NewShell).
// It is meant to be unique enough to not be seen in command outputs.
const Prompt = "PROMPT> "
    34  
    35  // Simplistic shell string escape.
    36  func shellEscape(s string) string {
    37  	// specialChars is used to determine whether s needs quoting at all.
    38  	const specialChars = "\\'\"`${[|&;<>()*?! \t\n"
    39  	// If s needs quoting, escapedChars is the set of characters that are
    40  	// escaped with a backslash.
    41  	const escapedChars = "\\\"$`"
    42  	if len(s) == 0 {
    43  		return "''"
    44  	}
    45  	if !strings.ContainsAny(s, specialChars) {
    46  		return s
    47  	}
    48  	var b bytes.Buffer
    49  	b.WriteString("\"")
    50  	for _, c := range s {
    51  		if strings.ContainsAny(string(c), escapedChars) {
    52  			b.WriteString("\\")
    53  		}
    54  		b.WriteRune(c)
    55  	}
    56  	b.WriteString("\"")
    57  	return b.String()
    58  }
    59  
// byteOrError is the element type of Shell.readCh: the reader goroutine sends
// either one byte read from the PTY (err is nil) or a read error (b is unset).
type byteOrError struct {
	b   byte
	err error
}
    64  
// Shell manages a /bin/sh invocation with convenience functions to handle I/O.
// The shell is run in its own interactive TTY and should present its prompt.
// Instances are created with NewShell.
type Shell struct {
	// cmd is a reference to the underlying sh process.
	cmd *exec.Cmd
	// cmdFinished is closed (by monitorExit) when cmd exits.
	cmdFinished chan struct{}

	// echo is whether the shell will echo input back to us.
	// This helps setting expectations of getting feedback of written bytes.
	// It is updated by RefreshSTTY.
	echo bool
	// Control characters we expect to see in the shell, i.e. the textual
	// representations expected back after sending an interrupt or an EOF.
	// They are updated by RefreshSTTY.
	controlCharIntr string
	controlCharEOF  string

	// ptyMaster and ptyReplica are the TTY pair associated with the shell.
	ptyMaster  *os.File
	ptyReplica *os.File
	// readCh is a channel where everything read from ptyMaster is written.
	// It is closed by the reader goroutine when it stops.
	readCh chan byteOrError

	// logger is used for logging. It may be nil.
	logger Logger
}
    89  
    90  // cleanup kills the shell process and closes the TTY.
    91  // Users of this library get a reference to this function with NewShell.
    92  func (s *Shell) cleanup() {
    93  	s.logf("cleanup", "Shell cleanup started.")
    94  	if s.cmd.ProcessState == nil {
    95  		if err := s.cmd.Process.Kill(); err != nil {
    96  			s.logf("cleanup", "cannot kill shell process: %v", err)
    97  		}
    98  		// We don't log the error returned by Wait because the monitorExit
    99  		// goroutine will already do so.
   100  		s.cmd.Wait()
   101  	}
   102  	s.ptyReplica.Close()
   103  	s.ptyMaster.Close()
   104  	// Wait for monitorExit goroutine to write exit status to the debug log.
   105  	<-s.cmdFinished
   106  	// Empty out everything in the readCh, but don't wait too long for it.
   107  	var extraBytes bytes.Buffer
   108  	unreadTimeout := time.After(100 * time.Millisecond)
   109  unreadLoop:
   110  	for {
   111  		select {
   112  		case r, ok := <-s.readCh:
   113  			if !ok {
   114  				break unreadLoop
   115  			} else if r.err == nil {
   116  				extraBytes.WriteByte(r.b)
   117  			}
   118  		case <-unreadTimeout:
   119  			break unreadLoop
   120  		}
   121  	}
   122  	if extraBytes.Len() > 0 {
   123  		s.logIO("unread", extraBytes.Bytes(), nil)
   124  	}
   125  	s.logf("cleanup", "Shell cleanup complete.")
   126  }
   127  
   128  // logIO logs byte I/O to both standard logging and the test log, if provided.
   129  func (s *Shell) logIO(prefix string, b []byte, err error) {
   130  	var sb strings.Builder
   131  	if len(b) > 0 {
   132  		sb.WriteString(fmt.Sprintf("%q", b))
   133  	} else {
   134  		sb.WriteString("(nothing)")
   135  	}
   136  	if err != nil {
   137  		sb.WriteString(fmt.Sprintf(" [error: %v]", err))
   138  	}
   139  	s.logf(prefix, "%s", sb.String())
   140  }
   141  
   142  // logf logs something to both standard logging and the test log, if provided.
   143  func (s *Shell) logf(prefix, format string, values ...any) {
   144  	if s.logger != nil {
   145  		s.logger.Logf("[%s] %s", prefix, fmt.Sprintf(format, values...))
   146  	}
   147  }
   148  
   149  // monitorExit waits for the shell process to exit and logs the exit result.
   150  func (s *Shell) monitorExit() {
   151  	if err := s.cmd.Wait(); err != nil {
   152  		s.logf("cmd", "shell process terminated: %v", err)
   153  	} else {
   154  		s.logf("cmd", "shell process terminated successfully")
   155  	}
   156  	close(s.cmdFinished)
   157  }
   158  
   159  // reader continuously reads the shell output and populates readCh.
   160  func (s *Shell) reader(ctx context.Context) {
   161  	b := make([]byte, 4096)
   162  	defer close(s.readCh)
   163  	for {
   164  		select {
   165  		case <-s.cmdFinished:
   166  			// Shell process terminated; stop trying to read.
   167  			return
   168  		case <-ctx.Done():
   169  			// Shell process will also have terminated in this case;
   170  			// stop trying to read.
   171  			// We don't print an error here because doing so would print this in the
   172  			// normal case where the context passed to NewShell is canceled at the
   173  			// end of a successful test.
   174  			return
   175  		default:
   176  			// Shell still running, try reading.
   177  		}
   178  		if got, err := s.ptyMaster.Read(b); err != nil {
   179  			s.readCh <- byteOrError{err: err}
   180  			if err == io.EOF {
   181  				return
   182  			}
   183  		} else {
   184  			for i := 0; i < got; i++ {
   185  				s.readCh <- byteOrError{b: b[i]}
   186  			}
   187  		}
   188  	}
   189  }
   190  
   191  // readByte reads a single byte, respecting the context.
   192  func (s *Shell) readByte(ctx context.Context) (byte, error) {
   193  	select {
   194  	case <-ctx.Done():
   195  		return 0, ctx.Err()
   196  	case r := <-s.readCh:
   197  		return r.b, r.err
   198  	}
   199  }
   200  
   201  // readLoop reads as many bytes as possible until the context expires, b is
   202  // full, or a short time passes. It returns how many bytes it has successfully
   203  // read.
   204  func (s *Shell) readLoop(ctx context.Context, b []byte) (int, error) {
   205  	soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second)
   206  	defer soonCancel()
   207  	var i int
   208  	for i = 0; i < len(b) && soonCtx.Err() == nil; i++ {
   209  		next, err := s.readByte(soonCtx)
   210  		if err != nil {
   211  			if i > 0 {
   212  				s.logIO("read", b[:i-1], err)
   213  			} else {
   214  				s.logIO("read", nil, err)
   215  			}
   216  			return i, err
   217  		}
   218  		b[i] = next
   219  	}
   220  	s.logIO("read", b[:i], soonCtx.Err())
   221  	return i, soonCtx.Err()
   222  }
   223  
   224  // readLine reads a single line. Strips out all \r characters for convenience.
   225  // Upon error, it will still return what it has read so far.
   226  // It will also exit quickly if the line content it has read so far (without a
   227  // line break) matches `prompt`.
   228  func (s *Shell) readLine(ctx context.Context, prompt string) ([]byte, error) {
   229  	soonCtx, soonCancel := context.WithTimeout(ctx, 5*time.Second)
   230  	defer soonCancel()
   231  	var lineData bytes.Buffer
   232  	var b byte
   233  	var err error
   234  	for soonCtx.Err() == nil && b != '\n' {
   235  		b, err = s.readByte(soonCtx)
   236  		if err != nil {
   237  			data := lineData.Bytes()
   238  			s.logIO("read", data, err)
   239  			return data, err
   240  		}
   241  		if b != '\r' {
   242  			lineData.WriteByte(b)
   243  		}
   244  		if bytes.Equal(lineData.Bytes(), []byte(prompt)) {
   245  			// Assume that there will not be any further output if we get the prompt.
   246  			// This avoids waiting for the read deadline just to read the prompt.
   247  			break
   248  		}
   249  	}
   250  	data := lineData.Bytes()
   251  	s.logIO("read", data, soonCtx.Err())
   252  	return data, soonCtx.Err()
   253  }
   254  
   255  // Expect verifies that the next `len(want)` bytes we read match `want`.
   256  func (s *Shell) Expect(ctx context.Context, want []byte) error {
   257  	errPrefix := fmt.Sprintf("want(%q)", want)
   258  	b := make([]byte, len(want))
   259  	got, err := s.readLoop(ctx, b)
   260  	if err != nil {
   261  		if ctx.Err() != nil {
   262  			return fmt.Errorf("%s: context done (%w), got: %q", errPrefix, err, b[:got])
   263  		}
   264  		return fmt.Errorf("%s: %w", errPrefix, err)
   265  	}
   266  	if got < len(want) {
   267  		return fmt.Errorf("%s: short read (read %d bytes, expected %d): %q", errPrefix, got, len(want), b[:got])
   268  	}
   269  	if !bytes.Equal(b, want) {
   270  		return fmt.Errorf("got %q want %q", b, want)
   271  	}
   272  	return nil
   273  }
   274  
// ExpectString verifies that the next `len(want)` bytes we read match `want`.
// It is the string counterpart of Expect.
func (s *Shell) ExpectString(ctx context.Context, want string) error {
	return s.Expect(ctx, []byte(want))
}
   279  
// ExpectPrompt verifies that the next few bytes we read are the shell prompt
// (the Prompt constant).
func (s *Shell) ExpectPrompt(ctx context.Context) error {
	return s.ExpectString(ctx, Prompt)
}
   284  
   285  // ExpectEmptyLine verifies that the next few bytes we read are an empty line,
   286  // as defined by any number of carriage or line break characters.
   287  func (s *Shell) ExpectEmptyLine(ctx context.Context) error {
   288  	line, err := s.readLine(ctx, Prompt)
   289  	if err != nil {
   290  		return fmt.Errorf("cannot read line: %w", err)
   291  	}
   292  	if strings.Trim(string(line), "\r\n") != "" {
   293  		return fmt.Errorf("line was not empty: %q", line)
   294  	}
   295  	return nil
   296  }
   297  
   298  // ExpectLine verifies that the next `len(want)` bytes we read match `want`,
   299  // followed by carriage returns or newline characters.
   300  func (s *Shell) ExpectLine(ctx context.Context, want string) error {
   301  	if err := s.ExpectString(ctx, want); err != nil {
   302  		return err
   303  	}
   304  	if err := s.ExpectEmptyLine(ctx); err != nil {
   305  		return fmt.Errorf("ExpectLine(%q): no line break: %w", want, err)
   306  	}
   307  	return nil
   308  }
   309  
   310  // Write writes `b` to the shell and verifies that all of them get written.
   311  func (s *Shell) Write(b []byte) error {
   312  	written, err := s.ptyMaster.Write(b)
   313  	s.logIO("write", b[:written], err)
   314  	if err != nil {
   315  		return fmt.Errorf("write(%q): %w", b, err)
   316  	}
   317  	if written != len(b) {
   318  		return fmt.Errorf("write(%q): wrote %d of %d bytes (%q)", b, written, len(b), b[:written])
   319  	}
   320  	return nil
   321  }
   322  
   323  // WriteLine writes `line` (to which \n will be appended) to the shell.
   324  // If the shell is in `echo` mode, it will also check that we got these bytes
   325  // back to read.
   326  func (s *Shell) WriteLine(ctx context.Context, line string) error {
   327  	if err := s.Write([]byte(line + "\n")); err != nil {
   328  		return err
   329  	}
   330  	if s.echo {
   331  		// We expect to see everything we've typed.
   332  		if err := s.ExpectLine(ctx, line); err != nil {
   333  			return fmt.Errorf("echo: %w", err)
   334  		}
   335  	}
   336  	return nil
   337  }
   338  
   339  // StartCommand is a convenience wrapper for WriteLine that mimics entering a
   340  // command line and pressing Enter. It does some basic shell argument escaping.
   341  func (s *Shell) StartCommand(ctx context.Context, cmd ...string) error {
   342  	escaped := make([]string, len(cmd))
   343  	for i, arg := range cmd {
   344  		escaped[i] = shellEscape(arg)
   345  	}
   346  	return s.WriteLine(ctx, strings.Join(escaped, " "))
   347  }
   348  
// GetCommandOutput gets all following bytes until the prompt is encountered.
// This is useful for matching the output of a command.
// All \r are removed for ease of matching.
// It is equivalent to calling ReadUntil with Prompt as the final line.
func (s *Shell) GetCommandOutput(ctx context.Context) ([]byte, error) {
	return s.ReadUntil(ctx, Prompt)
}
   355  
   356  // ReadUntil gets all following bytes until a certain line is encountered.
   357  // This final line is not returned as part of the output, but everything before
   358  // it (including the \n) is included.
   359  // This is useful for matching the output of a command.
   360  // All \r are removed for ease of matching.
   361  func (s *Shell) ReadUntil(ctx context.Context, finalLine string) ([]byte, error) {
   362  	var output bytes.Buffer
   363  	for ctx.Err() == nil {
   364  		line, err := s.readLine(ctx, finalLine)
   365  		if err != nil {
   366  			return nil, err
   367  		}
   368  		if bytes.Equal(line, []byte(finalLine)) {
   369  			break
   370  		}
   371  		// readLine ensures that `line` either matches `finalLine` or contains \n.
   372  		// Thus we can be confident that `line` has a \n here.
   373  		output.Write(line)
   374  	}
   375  	return output.Bytes(), ctx.Err()
   376  }
   377  
// RunCommand is a convenience wrapper for StartCommand + GetCommandOutput:
// it enters the (escaped) command line, then returns everything read until
// the next prompt.
func (s *Shell) RunCommand(ctx context.Context, cmd ...string) ([]byte, error) {
	if err := s.StartCommand(ctx, cmd...); err != nil {
		return nil, err
	}
	return s.GetCommandOutput(ctx)
}
   385  
   386  // RefreshSTTY interprets output from `stty -a` to check whether we are in echo
   387  // mode and other settings.
   388  // It will assume that any line matching `expectPrompt` means the end of
   389  // the `stty -a` output.
   390  // Why do this rather than using `tcgets`? Because this function can be used in
   391  // conjunction with sub-shell processes that can allocate their own TTYs.
   392  func (s *Shell) RefreshSTTY(ctx context.Context, expectPrompt string) error {
   393  	// Temporarily assume we will not get any output.
   394  	// If echo is actually on, we'll get the "stty -a" line as if it was command
   395  	// output. This is OK because we parse the output generously.
   396  	s.echo = false
   397  	if err := s.WriteLine(ctx, "stty -a"); err != nil {
   398  		return fmt.Errorf("could not run `stty -a`: %w", err)
   399  	}
   400  	sttyOutput, err := s.ReadUntil(ctx, expectPrompt)
   401  	if err != nil {
   402  		return fmt.Errorf("cannot get `stty -a` output: %w", err)
   403  	}
   404  
   405  	// Set default control characters in case we can't see them in the output.
   406  	s.controlCharIntr = "^C"
   407  	s.controlCharEOF = "^D"
   408  	// stty output has two general notations:
   409  	// `a = b;` (for control characters), and `option` vs `-option` (for boolean
   410  	// options). We parse both kinds here.
   411  	// For `a = b;`, `controlChar` contains `a`, and `previousToken` is used to
   412  	// set `controlChar` to `previousToken` when we see an "=" token.
   413  	var previousToken, controlChar string
   414  	for _, token := range strings.Fields(string(sttyOutput)) {
   415  		if controlChar != "" {
   416  			value := strings.TrimSuffix(token, ";")
   417  			switch controlChar {
   418  			case "intr":
   419  				s.controlCharIntr = value
   420  			case "eof":
   421  				s.controlCharEOF = value
   422  			}
   423  			controlChar = ""
   424  		} else {
   425  			switch token {
   426  			case "=":
   427  				controlChar = previousToken
   428  			case "-echo":
   429  				s.echo = false
   430  			case "echo":
   431  				s.echo = true
   432  			}
   433  		}
   434  		previousToken = token
   435  	}
   436  	s.logf("stty", "refreshed settings: echo=%v, intr=%q, eof=%q", s.echo, s.controlCharIntr, s.controlCharEOF)
   437  	return nil
   438  }
   439  
   440  // sendControlCode sends `code` to the shell and expects to see `repr`.
   441  // If `expectLinebreak` is true, it also expects to see a linebreak.
   442  func (s *Shell) sendControlCode(ctx context.Context, code byte, repr string, expectLinebreak bool) error {
   443  	if err := s.Write([]byte{code}); err != nil {
   444  		return fmt.Errorf("cannot send %q: %w", code, err)
   445  	}
   446  	if err := s.ExpectString(ctx, repr); err != nil {
   447  		return fmt.Errorf("did not see %s: %w", repr, err)
   448  	}
   449  	if expectLinebreak {
   450  		if err := s.ExpectEmptyLine(ctx); err != nil {
   451  			return fmt.Errorf("linebreak after %s: %v", repr, err)
   452  		}
   453  	}
   454  	return nil
   455  }
   456  
// SendInterrupt sends the \x03 (Ctrl+C) control character to the shell and
// expects to read back its textual representation (controlCharIntr).
// If `expectLinebreak` is true, it also expects to see a linebreak.
func (s *Shell) SendInterrupt(ctx context.Context, expectLinebreak bool) error {
	return s.sendControlCode(ctx, 0x03, s.controlCharIntr, expectLinebreak)
}
   461  
// SendEOF sends the \x04 (Ctrl+D) control character to the shell and expects
// to read back its textual representation (controlCharEOF).
// If `expectLinebreak` is true, it also expects to see a linebreak.
func (s *Shell) SendEOF(ctx context.Context, expectLinebreak bool) error {
	return s.sendControlCode(ctx, 0x04, s.controlCharEOF, expectLinebreak)
}
   466  
   467  // NewShell returns a new managed sh process along with a cleanup function.
   468  // The caller is expected to call this function once it no longer needs the
   469  // shell.
   470  // The optional passed-in logger will be used for logging.
   471  func NewShell(ctx context.Context, logger Logger) (*Shell, func(), error) {
   472  	ptyMaster, ptyReplica, err := pty.Open()
   473  	if err != nil {
   474  		return nil, nil, fmt.Errorf("cannot create PTY: %w", err)
   475  	}
   476  	cmd := exec.CommandContext(ctx, "/bin/sh", "--noprofile", "--norc", "-i")
   477  	cmd.Stdin = ptyReplica
   478  	cmd.Stdout = ptyReplica
   479  	cmd.Stderr = ptyReplica
   480  	cmd.SysProcAttr = &unix.SysProcAttr{
   481  		Setsid:  true,
   482  		Setctty: true,
   483  		Ctty:    0,
   484  	}
   485  	cmd.Env = append(cmd.Env, fmt.Sprintf("PS1=%s", Prompt))
   486  	if err := cmd.Start(); err != nil {
   487  		return nil, nil, fmt.Errorf("cannot start shell: %w", err)
   488  	}
   489  	s := &Shell{
   490  		cmd:         cmd,
   491  		cmdFinished: make(chan struct{}),
   492  		ptyMaster:   ptyMaster,
   493  		ptyReplica:  ptyReplica,
   494  		readCh:      make(chan byteOrError, 1<<20),
   495  		logger:      logger,
   496  	}
   497  	s.logf("creation", "Shell spawned.")
   498  	go s.monitorExit()
   499  	go s.reader(ctx)
   500  	setupCtx, setupCancel := context.WithTimeout(ctx, 5*time.Second)
   501  	defer setupCancel()
   502  	// We expect to see the prompt immediately on startup,
   503  	// since the shell is started in interactive mode.
   504  	if err := s.ExpectPrompt(setupCtx); err != nil {
   505  		s.cleanup()
   506  		return nil, nil, fmt.Errorf("did not get initial prompt: %w", err)
   507  	}
   508  	s.logf("creation", "Initial prompt observed.")
   509  	// Get initial TTY settings.
   510  	if err := s.RefreshSTTY(setupCtx, Prompt); err != nil {
   511  		s.cleanup()
   512  		return nil, nil, fmt.Errorf("cannot get initial STTY settings: %w", err)
   513  	}
   514  	return s, s.cleanup, nil
   515  }