github.com/sdboyer/gps@v0.16.3/cmd.go (about)

     1  package gps
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"fmt"
     7  	"os/exec"
     8  	"sync"
     9  	"time"
    10  
    11  	"github.com/Masterminds/vcs"
    12  )
    13  
    14  // monitoredCmd wraps a cmd and will keep monitoring the process until it
    15  // finishes, the provided context is canceled, or a certain amount of time has
    16  // passed and the command showed no signs of activity.
    17  type monitoredCmd struct {
    18  	cmd     *exec.Cmd
    19  	timeout time.Duration
    20  	stdout  *activityBuffer
    21  	stderr  *activityBuffer
    22  }
    23  
    24  func newMonitoredCmd(cmd *exec.Cmd, timeout time.Duration) *monitoredCmd {
    25  	stdout, stderr := newActivityBuffer(), newActivityBuffer()
    26  	cmd.Stdout, cmd.Stderr = stdout, stderr
    27  	return &monitoredCmd{
    28  		cmd:     cmd,
    29  		timeout: timeout,
    30  		stdout:  stdout,
    31  		stderr:  stderr,
    32  	}
    33  }
    34  
    35  // run will wait for the command to finish and return the error, if any. If the
    36  // command does not show any activity for more than the specified timeout the
    37  // process will be killed.
    38  func (c *monitoredCmd) run(ctx context.Context) error {
    39  	// Check for cancellation before even starting
    40  	if ctx.Err() != nil {
    41  		return ctx.Err()
    42  	}
    43  
    44  	ticker := time.NewTicker(c.timeout)
    45  	done := make(chan error, 1)
    46  	defer ticker.Stop()
    47  	go func() { done <- c.cmd.Run() }()
    48  
    49  	for {
    50  		select {
    51  		case <-ticker.C:
    52  			if c.hasTimedOut() {
    53  				// On windows it is apparently (?) possible for the process
    54  				// pointer to become nil without Run() having returned (and
    55  				// thus, passing through the done channel). Guard against this.
    56  				if c.cmd.Process != nil {
    57  					if err := c.cmd.Process.Kill(); err != nil {
    58  						return &killCmdError{err}
    59  					}
    60  				}
    61  
    62  				return &timeoutError{c.timeout}
    63  			}
    64  		case <-ctx.Done():
    65  			if c.cmd.Process != nil {
    66  				if err := c.cmd.Process.Kill(); err != nil {
    67  					return &killCmdError{err}
    68  				}
    69  			}
    70  			return ctx.Err()
    71  		case err := <-done:
    72  			return err
    73  		}
    74  	}
    75  }
    76  
    77  func (c *monitoredCmd) hasTimedOut() bool {
    78  	t := time.Now().Add(-c.timeout)
    79  	return c.stderr.lastActivity().Before(t) &&
    80  		c.stdout.lastActivity().Before(t)
    81  }
    82  
    83  func (c *monitoredCmd) combinedOutput(ctx context.Context) ([]byte, error) {
    84  	if err := c.run(ctx); err != nil {
    85  		return c.stderr.buf.Bytes(), err
    86  	}
    87  
    88  	return c.stdout.buf.Bytes(), nil
    89  }
    90  
    91  // activityBuffer is a buffer that keeps track of the last time a Write
    92  // operation was performed on it.
    93  type activityBuffer struct {
    94  	sync.Mutex
    95  	buf               *bytes.Buffer
    96  	lastActivityStamp time.Time
    97  }
    98  
    99  func newActivityBuffer() *activityBuffer {
   100  	return &activityBuffer{
   101  		buf: bytes.NewBuffer(nil),
   102  	}
   103  }
   104  
   105  func (b *activityBuffer) Write(p []byte) (int, error) {
   106  	b.Lock()
   107  	b.lastActivityStamp = time.Now()
   108  	defer b.Unlock()
   109  	return b.buf.Write(p)
   110  }
   111  
   112  func (b *activityBuffer) lastActivity() time.Time {
   113  	b.Lock()
   114  	defer b.Unlock()
   115  	return b.lastActivityStamp
   116  }
   117  
   118  type timeoutError struct {
   119  	timeout time.Duration
   120  }
   121  
   122  func (e timeoutError) Error() string {
   123  	return fmt.Sprintf("command killed after %s of no activity", e.timeout)
   124  }
   125  
   126  type killCmdError struct {
   127  	err error
   128  }
   129  
   130  func (e killCmdError) Error() string {
   131  	return fmt.Sprintf("error killing command: %s", e.err)
   132  }
   133  
   134  func runFromCwd(ctx context.Context, cmd string, args ...string) ([]byte, error) {
   135  	c := newMonitoredCmd(exec.Command(cmd, args...), 2*time.Minute)
   136  	return c.combinedOutput(ctx)
   137  }
   138  
   139  func runFromRepoDir(ctx context.Context, repo vcs.Repo, cmd string, args ...string) ([]byte, error) {
   140  	c := newMonitoredCmd(repo.CmdFromDir(cmd, args...), 2*time.Minute)
   141  	return c.combinedOutput(ctx)
   142  }