go.uber.org/yarpc@v1.72.1/internal/service-test/cmd.go (about)

     1  // Copyright (c) 2022 Uber Technologies, Inc.
     2  //
     3  // Permission is hereby granted, free of charge, to any person obtaining a copy
     4  // of this software and associated documentation files (the "Software"), to deal
     5  // in the Software without restriction, including without limitation the rights
     6  // to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
     7  // copies of the Software, and to permit persons to whom the Software is
     8  // furnished to do so, subject to the following conditions:
     9  //
    10  // The above copyright notice and this permission notice shall be included in
    11  // all copies or substantial portions of the Software.
    12  //
    13  // THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
    14  // IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
    15  // FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
    16  // AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
    17  // LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
    18  // OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
    19  // THE SOFTWARE.
    20  
    21  package main
    22  
    23  import (
    24  	"bytes"
    25  	"fmt"
    26  	"log"
    27  	"os"
    28  	"os/exec"
    29  	"strings"
    30  	"sync"
    31  	"syscall"
    32  	"time"
    33  
    34  	"github.com/mattn/go-shellwords"
    35  )
    36  
    37  type cmd struct {
    38  	cmd           *exec.Cmd
    39  	sleep         time.Duration
    40  	output        string
    41  	debug         bool
    42  	stdout        *bytes.Buffer
    43  	stderr        *bytes.Buffer
    44  	flushedStdout bool
    45  	flushedStderr bool
    46  	finished      bool
    47  	lock          sync.Mutex
    48  }
    49  
    50  func newCmd(cmdConfig *cmdConfig, dir string, debug bool) (*cmd, error) {
    51  	parser := shellwords.NewParser()
    52  	parser.ParseEnv = true
    53  	args, err := parser.Parse(cmdConfig.Command)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  	if len(args) == 0 {
    58  		return nil, fmt.Errorf("command evaulated to empty: %s", cmdConfig.Command)
    59  	}
    60  	execCmd := exec.Command(args[0], args[1:]...)
    61  	execCmd.Dir = dir
    62  	// https://medium.com/@felixge/killing-a-child-process-and-all-of-its-children-in-go-54079af94773
    63  	execCmd.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
    64  	cmd := &cmd{
    65  		execCmd,
    66  		time.Duration(cmdConfig.SleepMs) * time.Millisecond,
    67  		cmdConfig.Output,
    68  		debug,
    69  		bytes.NewBuffer(nil),
    70  		bytes.NewBuffer(nil),
    71  		false,
    72  		false,
    73  		false,
    74  		sync.Mutex{},
    75  	}
    76  	if cmdConfig.Input != "" {
    77  		execCmd.Stdin = strings.NewReader(cmdConfig.Input)
    78  	}
    79  	execCmd.Stdout = cmd.stdout
    80  	execCmd.Stderr = cmd.stderr
    81  	return cmd, nil
    82  }
    83  
    84  func (c *cmd) Start() error {
    85  	c.debugPrintf("starting")
    86  	if err := c.cmd.Start(); err != nil {
    87  		return c.wrapError("failed to start", err)
    88  	}
    89  	c.debugPrintf("started")
    90  	if c.sleep != 0 {
    91  		c.debugPrintf("sleeping")
    92  		<-time.After(c.sleep)
    93  		c.debugPrintf("done sleeping")
    94  	}
    95  	return nil
    96  }
    97  
    98  func (c *cmd) Wait() error {
    99  	c.debugPrintf("waiting")
   100  	if err := c.cmd.Wait(); err != nil {
   101  		return c.wrapError("failed", err)
   102  	}
   103  	c.debugPrintf("finished")
   104  	c.lock.Lock()
   105  	defer c.lock.Unlock()
   106  	c.finished = true
   107  	return nil
   108  }
   109  
   110  func (c *cmd) Validate() error {
   111  	if c.output == "" {
   112  		return nil
   113  	}
   114  	output := cleanOutput(c.stdout.String())
   115  	expectedOutput := cleanOutput(c.output)
   116  	if output != expectedOutput {
   117  		return c.wrapError("validation failed", fmt.Errorf("expected\n%s\ngot\n%s", expectedOutput, output))
   118  	}
   119  	return nil
   120  }
   121  
   122  func (c *cmd) Clean(suppressStdout bool) {
   123  	c.Kill()
   124  	if !suppressStdout {
   125  		c.FlushStdout()
   126  	}
   127  	c.FlushStderr()
   128  }
   129  
   130  func (c *cmd) Kill() {
   131  	c.lock.Lock()
   132  	defer c.lock.Unlock()
   133  	if c.finished {
   134  		return
   135  	}
   136  	if c.cmd.Process != nil {
   137  		c.debugPrintf("killing")
   138  		// https://medium.com/@felixge/killing-a-child-process-and-all-of-its-children-in-go-54079af94773
   139  		_ = syscall.Kill(-c.cmd.Process.Pid, syscall.SIGKILL)
   140  		c.finished = true
   141  	}
   142  }
   143  
   144  func (c *cmd) FlushStdout() {
   145  	c.lock.Lock()
   146  	defer c.lock.Unlock()
   147  	if c.flushedStdout {
   148  		return
   149  	}
   150  	if data := c.stdout.Bytes(); len(data) > 0 {
   151  		fmt.Print(string(data))
   152  	}
   153  	c.flushedStdout = true
   154  }
   155  
   156  func (c *cmd) FlushStderr() {
   157  	c.lock.Lock()
   158  	defer c.lock.Unlock()
   159  	if c.flushedStderr {
   160  		return
   161  	}
   162  	if data := c.stderr.Bytes(); len(data) > 0 {
   163  		fmt.Fprint(os.Stderr, string(data))
   164  	}
   165  	c.flushedStderr = true
   166  }
   167  
   168  func (c *cmd) String() string {
   169  	if len(c.cmd.Args) == 0 {
   170  		return c.cmd.Path
   171  	}
   172  	return strings.Join(c.cmd.Args, " ")
   173  }
   174  
   175  func (c *cmd) wrapError(msg string, err error) error {
   176  	return fmt.Errorf("%v: %s: %v", c, msg, err)
   177  }
   178  
   179  func (c *cmd) debugPrintf(format string, args ...interface{}) {
   180  	if c.debug {
   181  		args = append([]interface{}{c}, args...)
   182  		log.Printf("%v: "+format, args...)
   183  	}
   184  }
   185  
   186  func cleanOutput(output string) string {
   187  	output = strings.TrimSpace(output)
   188  	lines := strings.Split(output, "\n")
   189  	cleanedLines := make([]string, 0, len(lines))
   190  	for _, line := range lines {
   191  		if line != "" {
   192  			cleanedLines = append(cleanedLines, strings.TrimSpace(line))
   193  		}
   194  	}
   195  	return strings.Join(cleanedLines, "\n")
   196  }