github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/worker/simplesignalhandler/signalwatcher_test.go (about)

     1  // Copyright 2023 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package simplesignalhandler_test
     5  
     6  import (
     7  	"fmt"
     8  	"os"
     9  	"syscall"
    10  
    11  	"github.com/juju/errors"
    12  	"github.com/juju/loggo"
    13  	jc "github.com/juju/testing/checkers"
    14  	. "gopkg.in/check.v1"
    15  
    16  	ssh "github.com/juju/juju/worker/simplesignalhandler"
    17  )
    18  
    19  type signalSuite struct {
    20  }
    21  
    22  var _ = Suite(&signalSuite{})
    23  
    24  func (_ *signalSuite) TestSignalHandling(c *C) {
    25  	testErr := errors.ConstError("test")
    26  	handler := ssh.SignalHandlerFunc(func(sig os.Signal) error {
    27  		return testErr
    28  	})
    29  
    30  	sigChan := make(chan os.Signal, 0)
    31  
    32  	watcher, err := ssh.NewSignalWatcher(loggo.Logger{}, sigChan, handler)
    33  	c.Assert(err, jc.ErrorIsNil)
    34  
    35  	sigChan <- syscall.SIGTERM
    36  
    37  	err = watcher.Wait()
    38  	c.Assert(errors.Is(err, testErr), jc.IsTrue)
    39  }
    40  
    41  func (_ *signalSuite) TestSignalHandlingClosed(c *C) {
    42  	handler := ssh.SignalHandlerFunc(func(sig os.Signal) error {
    43  		return fmt.Errorf("should not be called")
    44  	})
    45  
    46  	sigChan := make(chan os.Signal, 0)
    47  
    48  	watcher, err := ssh.NewSignalWatcher(loggo.Logger{}, sigChan, handler)
    49  	c.Assert(err, jc.ErrorIsNil)
    50  
    51  	close(sigChan)
    52  
    53  	err = watcher.Wait()
    54  	c.Assert(err.Error(), Equals, "signal channel closed unexpectedly")
    55  }
    56  
    57  func (_ *signalSuite) TestDefaultSignalHandlerNilMap(c *C) {
    58  	testErr := errors.ConstError("test")
    59  	err := ssh.SignalHandler(testErr, nil)(syscall.SIGTERM)
    60  	c.Assert(errors.Is(err, testErr), jc.IsTrue)
    61  }
    62  
    63  func (_ *signalSuite) TestDefaultSignalHandlerNoMap(c *C) {
    64  	testErr := errors.ConstError("test")
    65  	err := ssh.SignalHandler(testErr, map[os.Signal]error{
    66  		syscall.SIGINT: errors.New("test error"),
    67  	})(syscall.SIGTERM)
    68  	c.Assert(errors.Is(err, testErr), jc.IsTrue)
    69  }
    70  
    71  func (_ *signalSuite) TestDefaultSignalHandlerMap(c *C) {
    72  	testErr := errors.ConstError("test")
    73  	err := ssh.SignalHandler(testErr, map[os.Signal]error{
    74  		syscall.SIGINT: errors.New("test error"),
    75  	})(syscall.SIGINT)
    76  	c.Assert(errors.Is(err, testErr), jc.IsFalse)
    77  	c.Assert(err.Error(), Equals, "test error")
    78  }