github.com/safing/portbase@v0.19.5/run/main.go (about)

     1  package run
     2  
     3  import (
     4  	"bufio"
     5  	"errors"
     6  	"flag"
     7  	"fmt"
     8  	"io"
     9  	"os"
    10  	"os/signal"
    11  	"runtime/pprof"
    12  	"syscall"
    13  	"time"
    14  
    15  	"github.com/safing/portbase/log"
    16  	"github.com/safing/portbase/modules"
    17  )
    18  
    19  var (
    20  	printStackOnExit   bool
    21  	enableInputSignals bool
    22  
    23  	sigUSR1 = syscall.Signal(0xa) // dummy for windows
    24  )
    25  
    26  func init() {
    27  	flag.BoolVar(&printStackOnExit, "print-stack-on-exit", false, "prints the stack before of shutting down")
    28  	flag.BoolVar(&enableInputSignals, "input-signals", false, "emulate signals using stdin")
    29  }
    30  
    31  // Run execute a full program lifecycle (including signal handling) based on modules. Just empty-import required packages and do os.Exit(run.Run()).
    32  func Run() int {
    33  	// Start
    34  	err := modules.Start()
    35  	if err != nil {
    36  		// Immediately return for a clean exit.
    37  		if errors.Is(err, modules.ErrCleanExit) {
    38  			return 0
    39  		}
    40  
    41  		if printStackOnExit {
    42  			printStackTo(os.Stdout, "PRINTING STACK ON EXIT (STARTUP ERROR)")
    43  		}
    44  
    45  		// Trigger shutdown and wait for it to complete.
    46  		_ = modules.Shutdown()
    47  		exitCode := modules.GetExitStatusCode()
    48  
    49  		// Return the exit code, if it was set.
    50  		if exitCode > 0 {
    51  			return exitCode
    52  		}
    53  
    54  		// Otherwise, return a default 1.
    55  		return 1
    56  	}
    57  
    58  	// Shutdown
    59  	// catch interrupt for clean shutdown
    60  	signalCh := make(chan os.Signal, 1)
    61  	if enableInputSignals {
    62  		go inputSignals(signalCh)
    63  	}
    64  	signal.Notify(
    65  		signalCh,
    66  		os.Interrupt,
    67  		syscall.SIGHUP,
    68  		syscall.SIGINT,
    69  		syscall.SIGTERM,
    70  		syscall.SIGQUIT,
    71  		sigUSR1,
    72  	)
    73  
    74  signalLoop:
    75  	for {
    76  		select {
    77  		case sig := <-signalCh:
    78  			// only print and continue to wait if SIGUSR1
    79  			if sig == sigUSR1 {
    80  				printStackTo(os.Stderr, "PRINTING STACK ON REQUEST")
    81  				continue signalLoop
    82  			}
    83  
    84  			fmt.Println(" <INTERRUPT>")
    85  			log.Warning("main: program was interrupted, shutting down.")
    86  
    87  			// catch signals during shutdown
    88  			go func() {
    89  				forceCnt := 5
    90  				for {
    91  					<-signalCh
    92  					forceCnt--
    93  					if forceCnt > 0 {
    94  						fmt.Printf(" <INTERRUPT> again, but already shutting down. %d more to force.\n", forceCnt)
    95  					} else {
    96  						printStackTo(os.Stderr, "PRINTING STACK ON FORCED EXIT")
    97  						os.Exit(1)
    98  					}
    99  				}
   100  			}()
   101  
   102  			if printStackOnExit {
   103  				printStackTo(os.Stdout, "PRINTING STACK ON EXIT")
   104  			}
   105  
   106  			go func() {
   107  				time.Sleep(3 * time.Minute)
   108  				printStackTo(os.Stderr, "PRINTING STACK - TAKING TOO LONG FOR SHUTDOWN")
   109  				os.Exit(1)
   110  			}()
   111  
   112  			_ = modules.Shutdown()
   113  			break signalLoop
   114  
   115  		case <-modules.ShuttingDown():
   116  			break signalLoop
   117  		}
   118  	}
   119  
   120  	// wait for shutdown to complete, then exit
   121  	return modules.GetExitStatusCode()
   122  }
   123  
   124  func inputSignals(signalCh chan os.Signal) {
   125  	scanner := bufio.NewScanner(os.Stdin)
   126  	for scanner.Scan() {
   127  		switch scanner.Text() {
   128  		case "SIGHUP":
   129  			signalCh <- syscall.SIGHUP
   130  		case "SIGINT":
   131  			signalCh <- syscall.SIGINT
   132  		case "SIGQUIT":
   133  			signalCh <- syscall.SIGQUIT
   134  		case "SIGTERM":
   135  			signalCh <- syscall.SIGTERM
   136  		case "SIGUSR1":
   137  			signalCh <- sigUSR1
   138  		}
   139  	}
   140  }
   141  
   142  func printStackTo(writer io.Writer, msg string) {
   143  	_, err := fmt.Fprintf(writer, "===== %s =====\n", msg)
   144  	if err == nil {
   145  		err = pprof.Lookup("goroutine").WriteTo(writer, 1)
   146  	}
   147  	if err != nil {
   148  		log.Errorf("main: failed to write stack trace: %s", err)
   149  	}
   150  }