github.com/hernad/nomad@v1.6.112/drivers/shared/executor/exec_utils.go (about)

     1  // Copyright (c) HashiCorp, Inc.
     2  // SPDX-License-Identifier: MPL-2.0
     3  
     4  package executor
     5  
import (
	"context"
	"errors"
	"fmt"
	"io"
	"os"
	"os/exec"
	"sync"
	"syscall"

	hclog "github.com/hashicorp/go-hclog"
	"github.com/hernad/nomad/plugins/drivers"
	dproto "github.com/hernad/nomad/plugins/drivers/proto"
)
    19  
    20  // execHelper is a convenient wrapper for starting and executing commands, and handling their output
    21  type execHelper struct {
    22  	logger hclog.Logger
    23  
    24  	// newTerminal function creates a tty appropriate for the command
    25  	// The returned pty end of tty function is to be called after process start.
    26  	newTerminal func() (pty func() (*os.File, error), tty *os.File, err error)
    27  
    28  	// setTTY is a callback to configure the command with slave end of the tty of the terminal, when tty is enabled
    29  	setTTY func(tty *os.File) error
    30  
    31  	// setTTY is a callback to configure the command with std{in|out|err}, when tty is disabled
    32  	setIO func(stdin io.Reader, stdout, stderr io.Writer) error
    33  
    34  	// processStart starts the process, like `exec.Cmd.Start()`
    35  	processStart func() error
    36  
    37  	// processWait blocks until command terminates and returns its final state
    38  	processWait func() (*os.ProcessState, error)
    39  }
    40  
    41  func (e *execHelper) run(ctx context.Context, tty bool, stream drivers.ExecTaskStream) error {
    42  	if tty {
    43  		return e.runTTY(ctx, stream)
    44  	}
    45  	return e.runNoTTY(ctx, stream)
    46  }
    47  
    48  func (e *execHelper) runTTY(ctx context.Context, stream drivers.ExecTaskStream) error {
    49  	ptyF, tty, err := e.newTerminal()
    50  	if err != nil {
    51  		return fmt.Errorf("failed to open a tty: %v", err)
    52  	}
    53  	defer tty.Close()
    54  
    55  	if err := e.setTTY(tty); err != nil {
    56  		return fmt.Errorf("failed to set command tty: %v", err)
    57  	}
    58  	if err := e.processStart(); err != nil {
    59  		return fmt.Errorf("failed to start command: %v", err)
    60  	}
    61  
    62  	var wg sync.WaitGroup
    63  	errCh := make(chan error, 3)
    64  
    65  	pty, err := ptyF()
    66  	if err != nil {
    67  		return fmt.Errorf("failed to get pty: %v", err)
    68  	}
    69  
    70  	defer pty.Close()
    71  	wg.Add(1)
    72  	go handleStdin(e.logger, pty, stream, errCh)
    73  	// when tty is on, stdout and stderr point to the same pty so only read once
    74  	go handleStdout(e.logger, pty, &wg, stream.Send, errCh)
    75  
    76  	ps, err := e.processWait()
    77  
    78  	// force close streams to close out the stream copying goroutines
    79  	tty.Close()
    80  
    81  	// wait until we get all process output
    82  	wg.Wait()
    83  
    84  	// wait to flush out output
    85  	stream.Send(cmdExitResult(ps, err))
    86  
    87  	select {
    88  	case cerr := <-errCh:
    89  		return cerr
    90  	default:
    91  		return nil
    92  	}
    93  }
    94  
    95  func (e *execHelper) runNoTTY(ctx context.Context, stream drivers.ExecTaskStream) error {
    96  	var sendLock sync.Mutex
    97  	send := func(v *drivers.ExecTaskStreamingResponseMsg) error {
    98  		sendLock.Lock()
    99  		defer sendLock.Unlock()
   100  
   101  		return stream.Send(v)
   102  	}
   103  
   104  	stdinPr, stdinPw := io.Pipe()
   105  	stdoutPr, stdoutPw := io.Pipe()
   106  	stderrPr, stderrPw := io.Pipe()
   107  
   108  	defer stdoutPw.Close()
   109  	defer stderrPw.Close()
   110  
   111  	if err := e.setIO(stdinPr, stdoutPw, stderrPw); err != nil {
   112  		return fmt.Errorf("failed to set command io: %v", err)
   113  	}
   114  
   115  	if err := e.processStart(); err != nil {
   116  		return fmt.Errorf("failed to start command: %v", err)
   117  	}
   118  
   119  	var wg sync.WaitGroup
   120  	errCh := make(chan error, 3)
   121  
   122  	wg.Add(2)
   123  	go handleStdin(e.logger, stdinPw, stream, errCh)
   124  	go handleStdout(e.logger, stdoutPr, &wg, send, errCh)
   125  	go handleStderr(e.logger, stderrPr, &wg, send, errCh)
   126  
   127  	ps, err := e.processWait()
   128  
   129  	// force close streams to close out the stream copying goroutines
   130  	stdinPr.Close()
   131  	stdoutPw.Close()
   132  	stderrPw.Close()
   133  
   134  	// wait until we get all process output
   135  	wg.Wait()
   136  
   137  	// wait to flush out output
   138  	stream.Send(cmdExitResult(ps, err))
   139  
   140  	select {
   141  	case cerr := <-errCh:
   142  		return cerr
   143  	default:
   144  		return nil
   145  	}
   146  }
   147  func cmdExitResult(ps *os.ProcessState, err error) *drivers.ExecTaskStreamingResponseMsg {
   148  	exitCode := -1
   149  
   150  	if ps == nil {
   151  		if ee, ok := err.(*exec.ExitError); ok {
   152  			ps = ee.ProcessState
   153  		}
   154  	}
   155  
   156  	if ps == nil {
   157  		exitCode = -2
   158  	} else if status, ok := ps.Sys().(syscall.WaitStatus); ok {
   159  		exitCode = status.ExitStatus()
   160  		if status.Signaled() {
   161  			const exitSignalBase = 128
   162  			signal := int(status.Signal())
   163  			exitCode = exitSignalBase + signal
   164  		}
   165  	}
   166  
   167  	return &drivers.ExecTaskStreamingResponseMsg{
   168  		Exited: true,
   169  		Result: &dproto.ExitResult{
   170  			ExitCode: int32(exitCode),
   171  		},
   172  	}
   173  }
   174  
   175  func handleStdin(logger hclog.Logger, stdin io.WriteCloser, stream drivers.ExecTaskStream, errCh chan<- error) {
   176  	for {
   177  		m, err := stream.Recv()
   178  		if isClosedError(err) {
   179  			return
   180  		} else if err != nil {
   181  			errCh <- err
   182  			return
   183  		}
   184  
   185  		if m.Stdin != nil {
   186  			if len(m.Stdin.Data) != 0 {
   187  				_, err := stdin.Write(m.Stdin.Data)
   188  				if err != nil {
   189  					errCh <- err
   190  					return
   191  				}
   192  			}
   193  			if m.Stdin.Close {
   194  				stdin.Close()
   195  			}
   196  		} else if m.TtySize != nil {
   197  			err := setTTYSize(stdin, m.TtySize.Height, m.TtySize.Width)
   198  			if err != nil {
   199  				errCh <- fmt.Errorf("failed to resize tty: %v", err)
   200  				return
   201  			}
   202  		}
   203  	}
   204  }
   205  
   206  func handleStdout(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) {
   207  	defer wg.Done()
   208  
   209  	buf := make([]byte, 4096)
   210  	for {
   211  		n, err := reader.Read(buf)
   212  		// always send output first if we read something
   213  		if n > 0 {
   214  			if err := send(&drivers.ExecTaskStreamingResponseMsg{
   215  				Stdout: &dproto.ExecTaskStreamingIOOperation{
   216  					Data: buf[:n],
   217  				},
   218  			}); err != nil {
   219  				errCh <- err
   220  				return
   221  			}
   222  		}
   223  
   224  		// then process error
   225  		if isClosedError(err) {
   226  			if err := send(&drivers.ExecTaskStreamingResponseMsg{
   227  				Stdout: &dproto.ExecTaskStreamingIOOperation{
   228  					Close: true,
   229  				},
   230  			}); err != nil {
   231  				errCh <- err
   232  				return
   233  			}
   234  			return
   235  		} else if err != nil {
   236  			errCh <- err
   237  			return
   238  		}
   239  
   240  	}
   241  }
   242  
   243  func handleStderr(logger hclog.Logger, reader io.Reader, wg *sync.WaitGroup, send func(*drivers.ExecTaskStreamingResponseMsg) error, errCh chan<- error) {
   244  	defer wg.Done()
   245  
   246  	buf := make([]byte, 4096)
   247  	for {
   248  		n, err := reader.Read(buf)
   249  		// always send output first if we read something
   250  		if n > 0 {
   251  			if err := send(&drivers.ExecTaskStreamingResponseMsg{
   252  				Stderr: &dproto.ExecTaskStreamingIOOperation{
   253  					Data: buf[:n],
   254  				},
   255  			}); err != nil {
   256  				errCh <- err
   257  				return
   258  			}
   259  		}
   260  
   261  		// then process error
   262  		if isClosedError(err) {
   263  			if err := send(&drivers.ExecTaskStreamingResponseMsg{
   264  				Stderr: &dproto.ExecTaskStreamingIOOperation{
   265  					Close: true,
   266  				},
   267  			}); err != nil {
   268  				errCh <- err
   269  				return
   270  			}
   271  			return
   272  		} else if err != nil {
   273  			errCh <- err
   274  			return
   275  		}
   276  
   277  	}
   278  }
   279  
   280  func isClosedError(err error) bool {
   281  	if err == nil {
   282  		return false
   283  	}
   284  
   285  	return err == io.EOF ||
   286  		err == io.ErrClosedPipe ||
   287  		isUnixEIOErr(err)
   288  }