github.com/richardwilkes/toolbox@v1.121.0/atexit/atexit.go (about)

     1  // Copyright (c) 2016-2024 by Richard A. Wilkes. All rights reserved.
     2  //
     3  // This Source Code Form is subject to the terms of the Mozilla Public
     4  // License, version 2.0. If a copy of the MPL was not distributed with
     5  // this file, You can obtain one at http://mozilla.org/MPL/2.0/.
     6  //
     7  // This Source Code Form is "Incompatible With Secondary Licenses", as
     8  // defined by the Mozilla Public License, version 2.0.
     9  
    10  // Package atexit provides functionality similar to the C standard library's atexit() call.
    11  package atexit
    12  
    13  import (
    14  	"fmt"
    15  	"os"
    16  	"os/signal"
    17  	"runtime"
    18  	"slices"
    19  	"sync"
    20  	"syscall"
    21  
    22  	"github.com/richardwilkes/toolbox/errs"
    23  )
    24  
    25  var (
    26  	// RecoveryHandler will be used to capture any panics caused by functions that have been installed when run during
    27  	// exit. It may be set to nil to silently ignore them.
    28  	RecoveryHandler errs.RecoveryHandler = func(err error) { errs.Log(err) }
    29  	lock            sync.Mutex
    30  	nextID          = 1
    31  	pairs           []pair
    32  	exiting         bool
    33  )
    34  
    35  type pair struct {
    36  	f  func()
    37  	id int
    38  }
    39  
    40  // Register a function to be run at exit. Returns an ID that can be used to remove the function later, if desired.
    41  // Registering a function after Exit() has been called (i.e. in a function that was registered) will have no effect.
    42  func Register(f func()) int {
    43  	lock.Lock()
    44  	defer lock.Unlock()
    45  	if nextID == 1 {
    46  		sigChan := make(chan os.Signal, 2)
    47  		signal.Notify(sigChan, syscall.SIGINT, syscall.SIGTERM)
    48  		go waitForSigInt(sigChan)
    49  	}
    50  	pairs = append(pairs, pair{id: nextID, f: f})
    51  	nextID++
    52  	return nextID - 1
    53  }
    54  
    55  func waitForSigInt(sigChan <-chan os.Signal) {
    56  	s := <-sigChan
    57  	if s == syscall.SIGINT {
    58  		fmt.Print("\b\b") // Removes the unsightly ^C in the terminal
    59  	}
    60  	Exit(1)
    61  }
    62  
    63  // Unregister a function that was previously registered to be run at exit. If the ID is no longer present, nothing
    64  // happens. Unregistering a function after Exit() has been called (i.e. in a function that was registered) will have no
    65  // effect.
    66  func Unregister(id int) {
    67  	lock.Lock()
    68  	defer lock.Unlock()
    69  	pairs = slices.DeleteFunc(pairs, func(p pair) bool { return p.id == id })
    70  }
    71  
    72  // Exit runs any registered exit functions in the inverse order they were registered and then exits the progream with
    73  // the specified status. If a previous call to Exit() is already being handled, this method does nothing but does not
    74  // return. Recursive calls to Exit() will trigger a panic, which the exit handling will catch and report, but will then
    75  // proceed with exit as normal. Note that once Exit() is called, no subsequent changes to the registered list of
    76  // functions will have an effect (i.e. you cannot Unregister() a function inside an exit handler to prevent its
    77  // execution, nor can you Register() a new function).
    78  func Exit(status int) {
    79  	var pcs [512]uintptr
    80  	recursive := false
    81  	n := runtime.Callers(2, pcs[:])
    82  	frames := runtime.CallersFrames(pcs[:n])
    83  	for {
    84  		frame, more := frames.Next()
    85  		if frame.Function == "github.com/richardwilkes/toolbox/atexit.Exit" {
    86  			recursive = true
    87  			break
    88  		}
    89  		if !more {
    90  			break
    91  		}
    92  	}
    93  	var f []func()
    94  	lock.Lock()
    95  	wasExiting := exiting
    96  	if !wasExiting {
    97  		exiting = true
    98  		f = make([]func(), len(pairs))
    99  		for i, p := range pairs {
   100  			f[i] = p.f
   101  		}
   102  	}
   103  	lock.Unlock()
   104  	if wasExiting {
   105  		if recursive {
   106  			panic("recursive call of atexit.Exit()") // force the recovery mechanism to deal with it
   107  		}
   108  		select {} // halt progress so that we don't return
   109  	} else {
   110  		for i := len(f) - 1; i >= 0; i-- {
   111  			run(f[i])
   112  		}
   113  		os.Exit(status)
   114  	}
   115  }
   116  
   117  func run(f func()) {
   118  	defer errs.Recovery(RecoveryHandler)
   119  	f()
   120  }