github.com/iqoqo/nomad@v0.11.3-0.20200911112621-d7021c74d101/drivers/shared/executor/exec_utils.go (about)

     1  package executor
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"io"
     7  	"os"
     8  	"os/exec"
     9  	"sync"
    10  	"syscall"
    11  
    12  	hclog "github.com/hashicorp/go-hclog"
    13  	"github.com/hashicorp/nomad/plugins/drivers"
    14  	dproto "github.com/hashicorp/nomad/plugins/drivers/proto"
    15  )
    16  
// execHelper is a convenient wrapper for starting and executing commands, and handling their output.
//
// All process-specific behavior is injected through the callback fields below.
// run invokes them in this order: setTTY (tty mode) or setIO (non-tty mode),
// then processStart, then processWait.
type execHelper struct {
	// logger is passed through to the stdin/stdout/stderr handler goroutines.
	logger hclog.Logger

	// newTerminal creates a tty appropriate for the command.
	// It returns the tty end to attach to the command, plus a function that
	// yields the pty end; that function is to be called after process start.
	newTerminal func() (pty func() (*os.File, error), tty *os.File, err error)

	// setTTY is a callback to configure the command with the tty (slave) end of the terminal, when tty is enabled
	setTTY func(tty *os.File) error

	// setIO is a callback to configure the command with std{in|out|err}, when tty is disabled
	setIO func(stdin io.Reader, stdout, stderr io.Writer) error

	// processStart starts the process, like `exec.Cmd.Start()`
	processStart func() error

	// processWait blocks until command terminates and returns its final state
	processWait func() (*os.ProcessState, error)
}
    37  
    38  func (e *execHelper) run(ctx context.Context, tty bool, stream drivers.ExecTaskStream) error {
    39  	if tty {
    40  		return e.runTTY(ctx, stream)
    41  	}
    42  	return e.runNoTTY(ctx, stream)
    43  }
    44  
    45  func (e *execHelper) runTTY(ctx context.Context, stream drivers.ExecTaskStream) error {
    46  	ptyF, tty, err := e.newTerminal()
    47  	if err != nil {
    48  		return fmt.Errorf("failed to open a tty: %v", err)
    49  	}
    50  	defer tty.Close()
    51  
    52  	if err := e.setTTY(tty); err != nil {
    53  		return fmt.Errorf("failed to set command tty: %v", err)
    54  	}
    55  	if err := e.processStart(); err != nil {
    56  		return fmt.Errorf("failed to start command: %v", err)
    57  	}
    58  
    59  	var wg sync.WaitGroup
    60  	errCh := make(chan error, 3)
    61  
    62  	pty, err := ptyF()
    63  	if err != nil {
    64  		return fmt.Errorf("failed to get pty: %v", err)
    65  	}
    66  
    67  	defer pty.Close()
    68  	wg.Add(1)
    69  	go handleStdin(e.logger, pty, stream, errCh)
    70  	// when tty is on, stdout and stderr point to the same pty so only read once
    71  	go handleStdout(e.logger, pty, &wg, stream.Send, errCh)
    72  
    73  	ps, err := e.processWait()
    74  
    75  	// force close streams to close out the stream copying goroutines
    76  	tty.Close()
    77  
    78  	// wait until we get all process output
    79  	wg.Wait()
    80  
    81  	// wait to flush out output
    82  	stream.Send(cmdExitResult(ps, err))
    83  
    84  	select {
    85  	case cerr := <-errCh:
    86  		return cerr
    87  	default:
    88  		return nil
    89  	}
    90  }
    91  
    92  func (e *execHelper) runNoTTY(ctx context.Context, stream drivers.ExecTaskStream) error {
    93  	var sendLock sync.Mutex
    94  	send := func(v *drivers.ExecTaskStreamingResponseMsg) error {
    95  		sendLock.Lock()
    96  		defer sendLock.Unlock()
    97  
    98  		return stream.Send(v)
    99  	}
   100  
   101  	stdinPr, stdinPw := io.Pipe()
   102  	stdoutPr, stdoutPw := io.Pipe()
   103  	stderrPr, stderrPw := io.Pipe()
   104  
   105  	defer stdoutPw.Close()
   106  	defer stderrPw.Close()
   107  
   108  	if err := e.setIO(stdinPr, stdoutPw, stderrPw); err != nil {
   109  		return fmt.Errorf("failed to set command io: %v", err)
   110  	}
   111  
   112  	if err := e.processStart(); err != nil {
   113  		return fmt.Errorf("failed to start command: %v", err)
   114  	}
   115  
   116  	var wg sync.WaitGroup
   117  	errCh := make(chan error, 3)
   118  
   119  	wg.Add(2)
   120  	go handleStdin(e.logger, stdinPw, stream, errCh)
   121  	go handleStdout(e.logger, stdoutPr, &wg, send, errCh)
   122  	go handleStderr(e.logger, stderrPr, &wg, send, errCh)
   123  
   124  	ps, err := e.processWait()
   125  
   126  	// force close streams to close out the stream copying goroutines
   127  	stdinPr.Close()
   128  	stdoutPw.Close()
   129  	stderrPw.Close()
   130  
   131  	// wait until we get all process output
   132  	wg.Wait()
   133  
   134  	// wait to flush out output
   135  	stream.Send(cmdExitResult(ps, err))
   136  
   137  	select {
   138  	case cerr := <-errCh:
   139  		return cerr
   140  	default:
   141  		return nil
   142  	}
   143  }
   144  func cmdExitResult(ps *os.ProcessState, err error) *drivers.ExecTaskStreamingResponseMsg {
   145  	exitCode := -1
   146  
   147  	if ps == nil {
   148  		if ee, ok := err.(*exec.ExitError); ok {
   149  			ps = ee.ProcessState
   150  		}
   151  	}
   152  
   153  	if ps == nil {
   154  		exitCode = -2
   155  	} else if status, ok := ps.Sys().(syscall.WaitStatus); ok {
   156  		exitCode = status.ExitStatus()
   157  		if status.Signaled() {
   158  			const exitSignalBase = 128
   159  			signal := int(status.Signal())
   160  			exitCode = exitSignalBase + signal
   161  		}
   162  	}
   163  
   164  	return &drivers.ExecTaskStreamingResponseMsg{
   165  		Exited: true,
   166  		Result: &dproto.ExitResult{
   167  			ExitCode: int32(exitCode),
   168  		},
   169  	}
   170  }
   171  
   172  func handleStdin(logger hclog.Logger, stdin io.WriteCloser, stream drivers.ExecTaskStream, errCh chan<- error) {
   173  	for {
   174  		m, err := stream.Recv()
   175  		if isClosedError(err) {
   176  			return
   177  		} else if err != nil {
   178  			errCh <- err
   179  			return
   180  		}
   181  
   182  		if m.Stdin != nil {
   183  			if len(m.Stdin.Data) != 0 {
   184  				_, err := stdin.Write(m.Stdin.Data)
   185  				if err != nil {
   186  					errCh <- err
   187  					return
   188  				}
   189  			}
   190  			if m.Stdin.Close {
   191  				stdin.Close()
   192  			}
   193  		} else if m.TtySize != nil {
   194  			err := setTTYSize(stdin, m.TtySize.Height, m.TtySize.Width)
   195  			if err != nil {
   196  				errCh <- fmt.Errorf("failed to resize tty: %v", err)
   197  				return
   198  			}
   199  		} else {
   200  			// ignore heartbeats
   201  		}
   202  	}
   203  }
   204  
   205  func handleStdout(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) {
   206  	defer wg.Done()
   207  
   208  	buf := make([]byte, 4096)
   209  	for {
   210  		n, err := reader.Read(buf)
   211  		// always send output first if we read something
   212  		if n > 0 {
   213  			if err := send(&drivers.ExecTaskStreamingResponseMsg{
   214  				Stdout: &dproto.ExecTaskStreamingIOOperation{
   215  					Data: buf[:n],
   216  				},
   217  			}); err != nil {
   218  				errCh <- err
   219  				return
   220  			}
   221  		}
   222  
   223  		// then process error
   224  		if isClosedError(err) {
   225  			if err := send(&drivers.ExecTaskStreamingResponseMsg{
   226  				Stdout: &dproto.ExecTaskStreamingIOOperation{
   227  					Close: true,
   228  				},
   229  			}); err != nil {
   230  				errCh <- err
   231  				return
   232  			}
   233  			return
   234  		} else if err != nil {
   235  			errCh <- err
   236  			return
   237  		}
   238  
   239  	}
   240  }
   241  
   242  func handleStderr(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) {
   243  	defer wg.Done()
   244  
   245  	buf := make([]byte, 4096)
   246  	for {
   247  		n, err := reader.Read(buf)
   248  		// always send output first if we read something
   249  		if n > 0 {
   250  			if err := send(&drivers.ExecTaskStreamingResponseMsg{
   251  				Stderr: &dproto.ExecTaskStreamingIOOperation{
   252  					Data: buf[:n],
   253  				},
   254  			}); err != nil {
   255  				errCh <- err
   256  				return
   257  			}
   258  		}
   259  
   260  		// then process error
   261  		if isClosedError(err) {
   262  			if err := send(&drivers.ExecTaskStreamingResponseMsg{
   263  				Stderr: &dproto.ExecTaskStreamingIOOperation{
   264  					Close: true,
   265  				},
   266  			}); err != nil {
   267  				errCh <- err
   268  				return
   269  			}
   270  			return
   271  		} else if err != nil {
   272  			errCh <- err
   273  			return
   274  		}
   275  
   276  	}
   277  }
   278  
   279  func isClosedError(err error) bool {
   280  	if err == nil {
   281  		return false
   282  	}
   283  
   284  	return err == io.EOF ||
   285  		err == io.ErrClosedPipe ||
   286  		isUnixEIOErr(err)
   287  }