github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/apiserver/apiserver_test.go (about)

     1  // Copyright 2016 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package apiserver_test
     5  
     6  import (
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"fmt"
    10  	"io"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"net/url"
    14  	"strings"
    15  	"time"
    16  
    17  	"github.com/gorilla/websocket"
    18  	"github.com/juju/clock"
    19  	"github.com/juju/cmd/v3"
    20  	"github.com/juju/collections/set"
    21  	jujuhttp "github.com/juju/http/v2"
    22  	"github.com/juju/loggo"
    23  	"github.com/juju/names/v5"
    24  	jc "github.com/juju/testing/checkers"
    25  	"github.com/juju/utils/v3"
    26  	"github.com/juju/worker/v3/dependency"
    27  	"github.com/juju/worker/v3/workertest"
    28  	gc "gopkg.in/check.v1"
    29  
    30  	"github.com/juju/juju/api"
    31  	"github.com/juju/juju/apiserver"
    32  	"github.com/juju/juju/apiserver/apiserverhttp"
    33  	"github.com/juju/juju/apiserver/observer"
    34  	"github.com/juju/juju/apiserver/observer/fakeobserver"
    35  	"github.com/juju/juju/apiserver/stateauthenticator"
    36  	apitesting "github.com/juju/juju/apiserver/testing"
    37  	"github.com/juju/juju/apiserver/websocket/websockettest"
    38  	"github.com/juju/juju/core/auditlog"
    39  	"github.com/juju/juju/core/cache"
    40  	corelogger "github.com/juju/juju/core/logger"
    41  	"github.com/juju/juju/core/presence"
    42  	"github.com/juju/juju/jujuclient"
    43  	psapiserver "github.com/juju/juju/pubsub/apiserver"
    44  	"github.com/juju/juju/pubsub/centralhub"
    45  	"github.com/juju/juju/rpc/params"
    46  	"github.com/juju/juju/state"
    47  	statetesting "github.com/juju/juju/state/testing"
    48  	"github.com/juju/juju/testing"
    49  	"github.com/juju/juju/worker/gate"
    50  	"github.com/juju/juju/worker/modelcache"
    51  	"github.com/juju/juju/worker/multiwatcher"
    52  )
    53  
    54  const (
    55  	ownerPassword = "very very secret"
    56  )
    57  
    58  // apiserverConfigFixture provides a complete, valid, apiserver.Config.
    59  // Unfortunately this also means that it requires State, at least until
    60  // we update the tests to stop expecting state-based authentication.
    61  //
    62  // apiserverConfigFixture does not run an API server; see apiserverBaseSuite
    63  // for that.
    64  type apiserverConfigFixture struct {
    65  	statetesting.StateSuite
    66  	authenticator *stateauthenticator.Authenticator
    67  	mux           *apiserverhttp.Mux
    68  	tlsConfig     *tls.Config
    69  	config        apiserver.ServerConfig
    70  }
    71  
    72  func (s *apiserverConfigFixture) SetUpTest(c *gc.C) {
    73  	s.StateSuite.SetUpTest(c)
    74  
    75  	authenticator, err := stateauthenticator.NewAuthenticator(s.StatePool, clock.WallClock)
    76  	c.Assert(err, jc.ErrorIsNil)
    77  	s.authenticator = authenticator
    78  	s.mux = apiserverhttp.NewMux()
    79  
    80  	certPool, err := api.CreateCertPool(testing.CACert)
    81  	if err != nil {
    82  		panic(err)
    83  	}
    84  	s.tlsConfig = api.NewTLSConfig(certPool)
    85  	s.tlsConfig.ServerName = "juju-apiserver"
    86  	s.tlsConfig.Certificates = []tls.Certificate{*testing.ServerTLSCert}
    87  	s.mux = apiserverhttp.NewMux()
    88  	allWatcherBacking, err := state.NewAllWatcherBacking(s.StatePool)
    89  	c.Assert(err, jc.ErrorIsNil)
    90  	multiWatcherWorker, err := multiwatcher.NewWorker(multiwatcher.Config{
    91  		Clock:                clock.WallClock,
    92  		Logger:               loggo.GetLogger("test"),
    93  		Backing:              allWatcherBacking,
    94  		PrometheusRegisterer: noopRegisterer{},
    95  	})
    96  	c.Assert(err, jc.ErrorIsNil)
    97  
    98  	// The worker itself is a coremultiwatcher.Factory.
    99  	s.AddCleanup(func(c *gc.C) { workertest.CleanKill(c, multiWatcherWorker) })
   100  
   101  	machineTag := names.NewMachineTag("0")
   102  	hub := centralhub.New(machineTag, centralhub.PubsubNoOpMetrics{})
   103  
   104  	initialized := gate.NewLock()
   105  	modelCache, err := modelcache.NewWorker(modelcache.Config{
   106  		StatePool:            s.StatePool,
   107  		Hub:                  hub,
   108  		InitializedGate:      initialized,
   109  		Logger:               loggo.GetLogger("test"),
   110  		WatcherFactory:       multiWatcherWorker.WatchController,
   111  		PrometheusRegisterer: noopRegisterer{},
   112  		Cleanup:              func() {},
   113  	}.WithDefaultRestartStrategy())
   114  	c.Assert(err, jc.ErrorIsNil)
   115  	s.AddCleanup(func(c *gc.C) { workertest.CleanKill(c, modelCache) })
   116  
   117  	select {
   118  	case <-initialized.Unlocked():
   119  	case <-time.After(10 * time.Second):
   120  		c.Error("model cache not initialized after 10 seconds")
   121  	}
   122  
   123  	var controller *cache.Controller
   124  	err = modelcache.ExtractCacheController(modelCache, &controller)
   125  	c.Assert(err, jc.ErrorIsNil)
   126  
   127  	s.config = apiserver.ServerConfig{
   128  		StatePool:                  s.StatePool,
   129  		Controller:                 controller,
   130  		MultiwatcherFactory:        multiWatcherWorker,
   131  		LocalMacaroonAuthenticator: s.authenticator,
   132  		Clock:                      clock.WallClock,
   133  		GetAuditConfig:             func() auditlog.Config { return auditlog.Config{} },
   134  		Tag:                        machineTag,
   135  		DataDir:                    c.MkDir(),
   136  		LogDir:                     c.MkDir(),
   137  		Hub:                        hub,
   138  		Presence:                   presence.New(clock.WallClock),
   139  		LeaseManager:               apitesting.StubLeaseManager{},
   140  		Mux:                        s.mux,
   141  		NewObserver:                func() observer.Observer { return &fakeobserver.Instance{} },
   142  		UpgradeComplete:            func() bool { return true },
   143  		RegisterIntrospectionHandlers: func(f func(path string, h http.Handler)) {
   144  			f("navel", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   145  				io.WriteString(w, "gazing")
   146  			}))
   147  		},
   148  		MetricsCollector: apiserver.NewMetricsCollector(),
   149  		ExecEmbeddedCommand: func(ctx *cmd.Context, store jujuclient.ClientStore, whitelist []string, cmdPlusArgs string) int {
   150  			allowed := set.NewStrings(whitelist...)
   151  			args := strings.Split(cmdPlusArgs, " ")
   152  			if !allowed.Contains(args[0]) {
   153  				fmt.Fprintf(ctx.Stderr, "%q not allowed\n", args[0])
   154  				return 1
   155  			}
   156  			ctrl, err := store.CurrentController()
   157  			if err != nil {
   158  				fmt.Fprintf(ctx.Stderr, "%s", err.Error())
   159  				return 1
   160  			}
   161  			model, err := store.CurrentModel(ctrl)
   162  			if err != nil {
   163  				fmt.Fprintf(ctx.Stderr, "%s", err.Error())
   164  				return 1
   165  			}
   166  			ad, err := store.AccountDetails(ctrl)
   167  			if err != nil {
   168  				fmt.Fprintf(ctx.Stderr, "%s", err.Error())
   169  				return 1
   170  			}
   171  			if strings.Contains(cmdPlusArgs, "macaroon error") {
   172  				fmt.Fprintf(ctx.Stderr, "ERROR: cannot get discharge from https://controller")
   173  				fmt.Fprintf(ctx.Stderr, "\n")
   174  			} else {
   175  				cmdStr := fmt.Sprintf("%s@%s:%s -> %s", ad.User, ctrl, model, cmdPlusArgs)
   176  				fmt.Fprintf(ctx.Stdout, "%s", cmdStr)
   177  				fmt.Fprintf(ctx.Stdout, "\n")
   178  			}
   179  			return 0
   180  		},
   181  		SysLogger: noopSysLogger{},
   182  		DBGetter:  apiserver.StubDBGetter{},
   183  	}
   184  }
   185  
   186  type noopSysLogger struct{}
   187  
   188  func (noopSysLogger) Log([]corelogger.LogRecord) error { return nil }
   189  
   190  // apiserverBaseSuite runs an API server.
   191  type apiserverBaseSuite struct {
   192  	apiserverConfigFixture
   193  	server    *httptest.Server
   194  	apiServer *apiserver.Server
   195  	baseURL   *url.URL
   196  }
   197  
   198  func (s *apiserverBaseSuite) SetUpTest(c *gc.C) {
   199  	s.apiserverConfigFixture.SetUpTest(c)
   200  
   201  	s.server = httptest.NewUnstartedServer(s.mux)
   202  	s.server.TLS = s.tlsConfig
   203  	s.server.StartTLS()
   204  	s.AddCleanup(func(c *gc.C) { s.server.Close() })
   205  	baseURL, err := url.Parse(s.server.URL)
   206  	c.Assert(err, jc.ErrorIsNil)
   207  	s.baseURL = baseURL
   208  	c.Logf("started HTTP server listening on %q", s.server.Listener.Addr())
   209  
   210  	server, err := apiserver.NewServer(s.config)
   211  	c.Assert(err, jc.ErrorIsNil)
   212  	s.AddCleanup(func(c *gc.C) {
   213  		workertest.DirtyKill(c, server)
   214  	})
   215  	s.apiServer = server
   216  
   217  	loggo.GetLogger("juju.apiserver").SetLogLevel(loggo.TRACE)
   218  	u, err := s.State.User(s.Owner)
   219  	c.Assert(err, jc.ErrorIsNil)
   220  	err = u.SetPassword(ownerPassword)
   221  	c.Assert(err, jc.ErrorIsNil)
   222  }
   223  
   224  // URL returns a URL for this server with the given path and
   225  // query parameters. The URL scheme will be "https".
   226  func (s *apiserverBaseSuite) URL(path string, queryParams url.Values) *url.URL {
   227  	url := *s.baseURL
   228  	url.Path = path
   229  	url.RawQuery = queryParams.Encode()
   230  	return &url
   231  }
   232  
   233  // sendHTTPRequest sends an HTTP request with an appropriate
   234  // username and password.
   235  func (s *apiserverBaseSuite) sendHTTPRequest(c *gc.C, p apitesting.HTTPRequestParams) *http.Response {
   236  	p.Tag = s.Owner.String()
   237  	p.Password = ownerPassword
   238  	return apitesting.SendHTTPRequest(c, p)
   239  }
   240  
   241  func (s *apiserverBaseSuite) newServerNoCleanup(c *gc.C, config apiserver.ServerConfig) *apiserver.Server {
   242  	// To ensure we don't get two servers using the same mux (in which
   243  	// case the original api server always handles requests), ensure
   244  	// the original one is stopped.
   245  	s.apiServer.Kill()
   246  	err := s.apiServer.Wait()
   247  	c.Assert(err, jc.ErrorIsNil)
   248  	srv, err := apiserver.NewServer(config)
   249  	c.Assert(err, jc.ErrorIsNil)
   250  	return srv
   251  }
   252  
   253  func (s *apiserverBaseSuite) newServer(c *gc.C, config apiserver.ServerConfig) *apiserver.Server {
   254  	srv := s.newServerNoCleanup(c, config)
   255  	s.AddCleanup(func(c *gc.C) {
   256  		workertest.CleanKill(c, srv)
   257  	})
   258  	return srv
   259  }
   260  
   261  func (s *apiserverBaseSuite) newServerDirtyKill(c *gc.C, config apiserver.ServerConfig) *apiserver.Server {
   262  	srv := s.newServerNoCleanup(c, config)
   263  	s.AddCleanup(func(c *gc.C) {
   264  		workertest.DirtyKill(c, srv)
   265  	})
   266  	return srv
   267  }
   268  
   269  // APIInfo returns an info struct that has the server's address and ca-cert
   270  // populated.
   271  func (s *apiserverBaseSuite) APIInfo(server *apiserver.Server) *api.Info {
   272  	address := s.server.Listener.Addr().String()
   273  	return &api.Info{
   274  		Addrs:  []string{address},
   275  		CACert: testing.CACert,
   276  	}
   277  }
   278  
   279  func (s *apiserverBaseSuite) openAPIAs(c *gc.C, srv *apiserver.Server, tag names.Tag, password, nonce string, controllerOnly bool) api.Connection {
   280  	apiInfo := s.APIInfo(srv)
   281  	apiInfo.Tag = tag
   282  	apiInfo.Password = password
   283  	apiInfo.Nonce = nonce
   284  	if !controllerOnly {
   285  		apiInfo.ModelTag = s.Model.ModelTag()
   286  	}
   287  	conn, err := api.Open(apiInfo, api.DialOpts{})
   288  	c.Assert(err, jc.ErrorIsNil)
   289  	c.Assert(conn, gc.NotNil)
   290  	s.AddCleanup(func(c *gc.C) {
   291  		conn.Close()
   292  	})
   293  	return conn
   294  }
   295  
   296  // OpenAPIAsNewMachine creates a new client connection logging in as the
   297  // controller owner. The returned api.Connection should not be closed by the
   298  // caller as a cleanup function has been registered to do that.
   299  func (s *apiserverBaseSuite) OpenAPIAsAdmin(c *gc.C, srv *apiserver.Server) api.Connection {
   300  	return s.openAPIAs(c, srv, s.Owner, ownerPassword, "", false)
   301  }
   302  
   303  // OpenAPIAsNewMachine creates a new machine entry that lives in system state,
   304  // and then uses that to open the API. The returned api.Connection should not be
   305  // closed by the caller as a cleanup function has been registered to do that.
   306  // The machine will run the supplied jobs; if none are given, JobHostUnits is assumed.
   307  func (s *apiserverBaseSuite) OpenAPIAsNewMachine(c *gc.C, srv *apiserver.Server, jobs ...state.MachineJob) (api.Connection, *state.Machine) {
   308  	if len(jobs) == 0 {
   309  		jobs = []state.MachineJob{state.JobHostUnits}
   310  	}
   311  	machine, err := s.State.AddMachine(state.UbuntuBase("12.10"), jobs...)
   312  	c.Assert(err, jc.ErrorIsNil)
   313  	password, err := utils.RandomPassword()
   314  	c.Assert(err, jc.ErrorIsNil)
   315  	err = machine.SetPassword(password)
   316  	c.Assert(err, jc.ErrorIsNil)
   317  	err = machine.SetProvisioned("foo", "", "fake_nonce", nil)
   318  	c.Assert(err, jc.ErrorIsNil)
   319  	return s.openAPIAs(c, srv, machine.Tag(), password, "fake_nonce", false), machine
   320  }
   321  
   322  func dialWebsocketFromURL(c *gc.C, server string, header http.Header) (*websocket.Conn, *http.Response, error) {
   323  	// TODO(rogpeppe) merge this with the very similar dialWebsocket function.
   324  	if header == nil {
   325  		header = http.Header{}
   326  	}
   327  	header.Set("Origin", "http://localhost/")
   328  	caCerts := x509.NewCertPool()
   329  	c.Assert(caCerts.AppendCertsFromPEM([]byte(testing.CACert)), jc.IsTrue)
   330  	tlsConfig := jujuhttp.SecureTLSConfig()
   331  	tlsConfig.RootCAs = caCerts
   332  	tlsConfig.ServerName = "juju-apiserver"
   333  
   334  	dialer := &websocket.Dialer{
   335  		TLSClientConfig: tlsConfig,
   336  	}
   337  	return dialer.Dial(server, header)
   338  }
   339  
   340  type apiserverSuite struct {
   341  	apiserverBaseSuite
   342  }
   343  
   344  var _ = gc.Suite(&apiserverSuite{})
   345  
   346  func (s *apiserverSuite) TestCleanStop(c *gc.C) {
   347  	workertest.CleanKill(c, s.apiServer)
   348  }
   349  
   350  func (s *apiserverSuite) TestRestartMessage(c *gc.C) {
   351  	_, err := s.config.Hub.Publish(psapiserver.RestartTopic, psapiserver.Restart{
   352  		LocalOnly: true,
   353  	})
   354  	c.Assert(err, jc.ErrorIsNil)
   355  
   356  	err = workertest.CheckKilled(c, s.apiServer)
   357  	c.Assert(err, gc.Equals, dependency.ErrBounce)
   358  }
   359  
   360  func (s *apiserverSuite) getHealth(c *gc.C) (string, int) {
   361  	uri := s.server.URL + "/health"
   362  	resp := apitesting.SendHTTPRequest(c, apitesting.HTTPRequestParams{Method: "GET", URL: uri})
   363  	body, err := io.ReadAll(resp.Body)
   364  	c.Assert(err, jc.ErrorIsNil)
   365  	result := string(body)
   366  	// Ensure that the last value is a carriage return.
   367  	c.Assert(strings.HasSuffix(result, "\n"), jc.IsTrue)
   368  	return strings.TrimSuffix(result, "\n"), resp.StatusCode
   369  }
   370  
   371  func (s *apiserverSuite) TestHealthRunning(c *gc.C) {
   372  	health, statusCode := s.getHealth(c)
   373  	c.Assert(health, gc.Equals, "running")
   374  	c.Assert(statusCode, gc.Equals, http.StatusOK)
   375  }
   376  
   377  func (s *apiserverSuite) TestHealthStopping(c *gc.C) {
   378  	wg := apiserver.ServerWaitGroup(s.apiServer)
   379  	wg.Add(1)
   380  
   381  	s.apiServer.Kill()
   382  	// There is a race here between the test and the goroutine setting
   383  	// the value, so loop until we see the right health, then exit.
   384  	timeout := time.After(testing.LongWait)
   385  	for {
   386  		health, statusCode := s.getHealth(c)
   387  		if health == "stopping" {
   388  			// Expected, we're done.
   389  			c.Assert(statusCode, gc.Equals, http.StatusServiceUnavailable)
   390  			wg.Done()
   391  			return
   392  		}
   393  		select {
   394  		case <-timeout:
   395  			c.Fatalf("health not set to stopping")
   396  		case <-time.After(testing.ShortWait):
   397  			// Look again.
   398  		}
   399  	}
   400  }
   401  
   402  func (s *apiserverSuite) TestEmbeddedCommand(c *gc.C) {
   403  	cmdArgs := params.CLICommands{
   404  		User:     "fred",
   405  		Commands: []string{"status --color"},
   406  	}
   407  	s.assertEmbeddedCommand(c, cmdArgs, "fred@interactive:test-admin/testmodel -> status --color", nil)
   408  }
   409  
   410  func (s *apiserverSuite) TestEmbeddedCommandNotAllowed(c *gc.C) {
   411  	cmdArgs := params.CLICommands{
   412  		User:     "fred",
   413  		Commands: []string{"bootstrap aws"},
   414  	}
   415  	s.assertEmbeddedCommand(c, cmdArgs, `"bootstrap" not allowed`, nil)
   416  }
   417  
   418  func (s *apiserverSuite) TestEmbeddedCommandMissingUser(c *gc.C) {
   419  	cmdArgs := params.CLICommands{
   420  		Commands: []string{"status --color"},
   421  	}
   422  	s.assertEmbeddedCommand(c, cmdArgs, "", &params.Error{Message: `CLI command for anonymous user not supported`, Code: "not supported"})
   423  }
   424  
   425  func (s *apiserverSuite) TestEmbeddedCommandInvalidUser(c *gc.C) {
   426  	cmdArgs := params.CLICommands{
   427  		User:     "123@",
   428  		Commands: []string{"status --color"},
   429  	}
   430  	s.assertEmbeddedCommand(c, cmdArgs, "", &params.Error{Message: `user name "123@" not valid`, Code: params.CodeNotValid})
   431  }
   432  
   433  func (s *apiserverSuite) TestEmbeddedCommandInvalidMacaroon(c *gc.C) {
   434  	cmdArgs := params.CLICommands{
   435  		User:     "fred",
   436  		Commands: []string{"status macaroon error"},
   437  	}
   438  	s.assertEmbeddedCommand(c, cmdArgs, "", &params.Error{
   439  		Code:    params.CodeDischargeRequired,
   440  		Message: `macaroon discharge required: cannot get discharge from https://controller`})
   441  }
   442  
   443  func (s *apiserverSuite) assertEmbeddedCommand(c *gc.C, cmdArgs params.CLICommands, expected string, resultErr *params.Error) {
   444  	address := s.server.Listener.Addr().String()
   445  	path := fmt.Sprintf("/model/%s/commands", s.State.ModelUUID())
   446  	commandURL := &url.URL{
   447  		Scheme: "wss",
   448  		Host:   address,
   449  		Path:   path,
   450  	}
   451  	conn, _, err := dialWebsocketFromURL(c, commandURL.String(), http.Header{})
   452  	c.Assert(err, jc.ErrorIsNil)
   453  	defer conn.Close()
   454  
   455  	// Read back the nil error, indicating that all is well.
   456  	websockettest.AssertJSONInitialErrorNil(c, conn)
   457  
   458  	done := make(chan struct{})
   459  	var result params.CLICommandStatus
   460  	go func() {
   461  		for {
   462  			var update params.CLICommandStatus
   463  			err := conn.ReadJSON(&update)
   464  			c.Assert(err, jc.ErrorIsNil)
   465  
   466  			result.Output = append(result.Output, update.Output...)
   467  			result.Done = update.Done
   468  			result.Error = update.Error
   469  			if result.Done {
   470  				done <- struct{}{}
   471  				break
   472  			}
   473  		}
   474  	}()
   475  
   476  	err = conn.WriteJSON(cmdArgs)
   477  	c.Assert(err, jc.ErrorIsNil)
   478  
   479  	select {
   480  	case <-done:
   481  	case <-time.After(testing.LongWait):
   482  		c.Fatalf("no command result")
   483  	}
   484  
   485  	// Close connection.
   486  	err = conn.Close()
   487  	c.Assert(err, jc.ErrorIsNil)
   488  
   489  	var expectedOutput []string
   490  	if expected != "" {
   491  		expectedOutput = []string{expected}
   492  	}
   493  	c.Assert(result, jc.DeepEquals, params.CLICommandStatus{
   494  		Output: expectedOutput,
   495  		Done:   true,
   496  		Error:  resultErr,
   497  	})
   498  }