github.com/vishnupahwa/lakctl@v0.0.2-alpha/control/app.go (about)

     1  package control
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"github.com/mitchellh/go-ps"
     7  	"github.com/vishnupahwa/lakctl/cmd/commands/options"
     8  	"log"
     9  	"os"
    10  	"os/exec"
    11  	"syscall"
    12  	"time"
    13  )
    14  
    15  // Start the command and the server
    16  func Start(ctx context.Context, run *options.Run, serve *options.Server) error {
    17  	cmdCtx, cancelFunc := context.WithCancel(ctx)
    18  	defer cancelFunc()
    19  	cmd := createCommand(cmdCtx, run)
    20  	must(cmd.Start())
    21  	cmdPtr := &cmd
    22  	runCtlServer(ctx, cmdPtr, cmdCtx, run, serve)
    23  	log.Println("lakctl closed")
    24  	return killGroupForProcess(*cmdPtr)
    25  }
    26  
    27  func createCommand(cmdCtx context.Context, run *options.Run) *exec.Cmd {
    28  	c := exec.CommandContext(cmdCtx, "bash", "-c", run.Command)
    29  	c.SysProcAttr = &syscall.SysProcAttr{Setpgid: true}
    30  	c.Stdout = os.Stdout
    31  	c.Stderr = os.Stderr
    32  	return c
    33  }
    34  
    35  // killGroupForProcess checks the PID is running and is a child process of lakctl before killing it's group and itself.
    36  // The cmd process kill and the wait is a safety check to make sure the process tree has fully been terminated.
    37  func killGroupForProcess(cmd *exec.Cmd) error {
    38  	pid, err := ps.FindProcess(cmd.Process.Pid)
    39  	if err != nil {
    40  		log.Printf("Cannot find PID %d: %v\n", cmd.Process.Pid, err)
    41  		return nil
    42  	}
    43  	if pid == nil {
    44  		log.Printf("Command previously running with %d is not running\n", cmd.Process.Pid)
    45  		return nil
    46  	}
    47  	if pid.PPid() != os.Getpid() {
    48  		log.Printf("Process %d has parent %d and is not a subprocess of lakctl (%d).", pid.Pid(), pid.PPid(), os.Getpid())
    49  		return nil
    50  	}
    51  	log.Printf("Stopping %s (PID: %d, PPID: %d)", pid.Executable(), pid.Pid(), pid.PPid())
    52  	errGroup := syscall.Kill(-pid.Pid(), syscall.SIGTERM)
    53  	errCmd := cmd.Process.Signal(syscall.SIGTERM)
    54  	waitWithTimeout(cmd)
    55  	return compositeErr(errCmd, errGroup)
    56  }
    57  
    58  func waitWithTimeout(cmd *exec.Cmd) {
    59  	wait := make(chan error, 1)
    60  	go func() { wait <- cmd.Wait() }()
    61  	select {
    62  	case <-wait:
    63  		return
    64  	case <-time.After(30 * time.Second):
    65  		log.Printf("Timed out waiting 30s for process to be killed\n")
    66  	}
    67  }
    68  
    69  func compositeErr(errs ...error) error {
    70  	message := ""
    71  	for _, err := range errs {
    72  		if err != nil {
    73  			message = message + err.Error() + " "
    74  		}
    75  	}
    76  	if len(message) > 0 {
    77  		return errors.New(message)
    78  	}
    79  	return nil
    80  }
    81  
    82  func must(err error) {
    83  	if err != nil {
    84  		log.Fatal(err)
    85  	}
    86  }