github.com/juju/juju@v0.0.0-20240327075706-a90865de2538/worker/controlsocket/worker_test.go (about)

     1  // Copyright 2023 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package controlsocket
     5  
     6  import (
     7  	"fmt"
     8  	"io/fs"
     9  	"net/http"
    10  	"os"
    11  	"path"
    12  
    13  	jc "github.com/juju/testing/checkers"
    14  	gc "gopkg.in/check.v1"
    15  )
    16  
    17  type workerSuite struct {
    18  	logger *fakeLogger
    19  }
    20  
    21  var _ = gc.Suite(&workerSuite{})
    22  
    23  func (s *workerSuite) SetUpTest(c *gc.C) {
    24  	s.logger = &fakeLogger{}
    25  }
    26  
    27  func (s *workerSuite) TestStartStopWorker(c *gc.C) {
    28  	tmpDir := c.MkDir()
    29  	socket := path.Join(tmpDir, "test.socket")
    30  
    31  	worker, err := NewWorker(Config{
    32  		State:      &fakeState{},
    33  		Logger:     s.logger,
    34  		SocketName: socket,
    35  	})
    36  	c.Assert(err, jc.ErrorIsNil)
    37  
    38  	// Check socket is created with correct permissions
    39  	fi, err := os.Stat(socket)
    40  	c.Assert(err, jc.ErrorIsNil)
    41  	c.Assert(fi.Mode(), gc.Equals, fs.ModeSocket|0700)
    42  
    43  	// Check server is up
    44  	cl := client(socket)
    45  	resp, err := cl.Get("http://a/foo")
    46  	c.Assert(err, jc.ErrorIsNil)
    47  	c.Assert(resp.StatusCode, gc.Equals, http.StatusNotFound)
    48  
    49  	worker.Kill()
    50  	err = worker.Wait()
    51  	c.Assert(err, jc.ErrorIsNil)
    52  
    53  	// Check server has stopped
    54  	resp, err = cl.Get("http://a/foo")
    55  	c.Assert(err, gc.ErrorMatches, ".*connection refused")
    56  
    57  	// No warnings/errors should have been logged
    58  	for _, entry := range s.logger.entries {
    59  		if entry.level == "ERROR" || entry.level == "WARNING" {
    60  			c.Errorf("%s: %s", entry.level, entry.msg)
    61  		}
    62  	}
    63  }
    64  
    65  type fakeLogger struct {
    66  	entries []logEntry
    67  }
    68  
    69  type logEntry struct{ level, msg string }
    70  
    71  func (f *fakeLogger) write(level string, format string, args ...any) {
    72  	f.entries = append(f.entries, logEntry{level, fmt.Sprintf(format, args...)})
    73  }
    74  
    75  func (f *fakeLogger) Errorf(format string, args ...any) {
    76  	f.write("ERROR", format, args...)
    77  }
    78  
    79  func (f *fakeLogger) Warningf(format string, args ...any) {
    80  	f.write("WARNING", format, args...)
    81  }
    82  
    83  func (f *fakeLogger) Infof(format string, args ...any) {
    84  	f.write("INFO", format, args...)
    85  }
    86  
    87  func (f *fakeLogger) Debugf(format string, args ...any) {
    88  	f.write("DEBUG", format, args...)
    89  }
    90  
    91  func (f *fakeLogger) Tracef(format string, args ...any) {
    92  	f.write("TRACE", format, args...)
    93  }