github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/api/apiclient_test.go (about)

     1  // Copyright 2014 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package api_test
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"crypto/tls"
    10  	"crypto/x509"
    11  	"encoding/pem"
    12  	"fmt"
    13  	"net"
    14  	"net/http"
    15  	"net/http/httptest"
    16  	"net/url"
    17  	"os"
    18  	"reflect"
    19  	"strings"
    20  	"sync"
    21  	"sync/atomic"
    22  	"time"
    23  
    24  	"github.com/juju/clock"
    25  	"github.com/juju/clock/testclock"
    26  	"github.com/juju/errors"
    27  	"github.com/juju/loggo"
    28  	"github.com/juju/names/v5"
    29  	proxyutils "github.com/juju/proxy"
    30  	"github.com/juju/testing"
    31  	jc "github.com/juju/testing/checkers"
    32  	gc "gopkg.in/check.v1"
    33  
    34  	"github.com/juju/juju/api"
    35  	"github.com/juju/juju/api/base"
    36  	apiclient "github.com/juju/juju/api/client/client"
    37  	"github.com/juju/juju/api/common"
    38  	apitesting "github.com/juju/juju/api/testing"
    39  	apiservertesting "github.com/juju/juju/apiserver/testing"
    40  	"github.com/juju/juju/controller"
    41  	"github.com/juju/juju/core/network"
    42  	jjtesting "github.com/juju/juju/juju/testing"
    43  	"github.com/juju/juju/rpc"
    44  	"github.com/juju/juju/rpc/jsoncodec"
    45  	"github.com/juju/juju/rpc/params"
    46  	jtesting "github.com/juju/juju/testing"
    47  	"github.com/juju/juju/utils/proxy"
    48  	jujuversion "github.com/juju/juju/version"
    49  )
    50  
// apiclientSuite groups the API client tests in this file. Its only
// field is the embedded jjtesting.JujuConnSuite, so the helpers the
// tests call on the suite (for example s.APIInfo) come from that
// embedded type.
type apiclientSuite struct {
	jjtesting.JujuConnSuite
}

// Register the suite with gocheck.
var _ = gc.Suite(&apiclientSuite{})
    56  
    57  func (s *apiclientSuite) TestDialAPIToModel(c *gc.C) {
    58  	info := s.APIInfo(c)
    59  	conn, location, err := api.DialAPI(info, api.DialOpts{})
    60  	c.Assert(err, jc.ErrorIsNil)
    61  	defer conn.Close()
    62  	assertConnAddrForModel(c, location, info.Addrs[0], s.State.ModelUUID())
    63  }
    64  
    65  func (s *apiclientSuite) TestDialAPIToRoot(c *gc.C) {
    66  	info := s.APIInfo(c)
    67  	info.ModelTag = names.NewModelTag("")
    68  	conn, location, err := api.DialAPI(info, api.DialOpts{})
    69  	c.Assert(err, jc.ErrorIsNil)
    70  	defer conn.Close()
    71  	assertConnAddrForRoot(c, location, info.Addrs[0])
    72  }
    73  
    74  func (s *apiclientSuite) TestDialAPIMultiple(c *gc.C) {
    75  	// Create a socket that proxies to the API server.
    76  	info := s.APIInfo(c)
    77  	serverAddr := info.Addrs[0]
    78  	proxy := testing.NewTCPProxy(c, serverAddr)
    79  	defer proxy.Close()
    80  
    81  	// Check that we can use the proxy to connect.
    82  	info.Addrs = []string{proxy.Addr()}
    83  	conn, location, err := api.DialAPI(info, api.DialOpts{})
    84  	c.Assert(err, jc.ErrorIsNil)
    85  	conn.Close()
    86  	assertConnAddrForModel(c, location, proxy.Addr(), s.State.ModelUUID())
    87  
    88  	// Now break Addrs[0], and ensure that Addrs[1]
    89  	// is successfully connected to.
    90  	proxy.Close()
    91  
    92  	info.Addrs = []string{proxy.Addr(), serverAddr}
    93  	conn, location, err = api.DialAPI(info, api.DialOpts{})
    94  	c.Assert(err, jc.ErrorIsNil)
    95  	conn.Close()
    96  	assertConnAddrForModel(c, location, serverAddr, s.State.ModelUUID())
    97  }
    98  
    99  func (s *apiclientSuite) TestDialAPIWithProxy(c *gc.C) {
   100  	info := s.APIInfo(c)
   101  	opts := api.DialOpts{IPAddrResolver: apitesting.IPAddrResolverMap{
   102  		"testing.invalid": {"0.1.1.1"},
   103  	}}
   104  	fakeAddr := "testing.invalid:1234"
   105  
   106  	// Confirm that the proxy configuration is used. See:
   107  	//     https://bugs.launchpad.net/juju/+bug/1698989
   108  	//
   109  	// TODO(axw) use github.com/elazarl/goproxy set up a real
   110  	// forward proxy, and confirm that we can dial a successful
   111  	// connection.
   112  	handler := func(w http.ResponseWriter, r *http.Request) {
   113  		if r.Method != "CONNECT" {
   114  			http.Error(w, fmt.Sprintf("invalid method %s", r.Method), http.StatusMethodNotAllowed)
   115  			return
   116  		}
   117  		if r.URL.Host != fakeAddr {
   118  			http.Error(w, fmt.Sprintf("unexpected host %s", r.URL.Host), http.StatusBadRequest)
   119  			return
   120  		}
   121  		http.Error(w, "🍵", http.StatusTeapot)
   122  	}
   123  	proxyServer := httptest.NewServer(http.HandlerFunc(handler))
   124  	defer proxyServer.Close()
   125  
   126  	err := proxy.DefaultConfig.Set(proxyutils.Settings{
   127  		Https: proxyServer.Listener.Addr().String(),
   128  	})
   129  	c.Assert(err, jc.ErrorIsNil)
   130  	defer proxy.DefaultConfig.Set(proxyutils.Settings{})
   131  
   132  	// Check that we can use the proxy to connect.
   133  	info.Addrs = []string{fakeAddr}
   134  	_, _, err = api.DialAPI(info, opts)
   135  	c.Assert(err, gc.ErrorMatches, "unable to connect to API: I'm a teapot")
   136  }
   137  
   138  func (s *apiclientSuite) TestDialAPIMultipleError(c *gc.C) {
   139  	var addrs []string
   140  
   141  	// count holds the number of times we've accepted a connection.
   142  	var count int32
   143  	for i := 0; i < 3; i++ {
   144  		listener, err := net.Listen("tcp", "127.0.0.1:0")
   145  		c.Assert(err, jc.ErrorIsNil)
   146  		defer listener.Close()
   147  		addrs = append(addrs, listener.Addr().String())
   148  		go func() {
   149  			for {
   150  				client, err := listener.Accept()
   151  				if err != nil {
   152  					return
   153  				}
   154  				atomic.AddInt32(&count, 1)
   155  				client.Close()
   156  			}
   157  		}()
   158  	}
   159  	info := s.APIInfo(c)
   160  	info.Addrs = addrs
   161  	_, _, err := api.DialAPI(info, api.DialOpts{})
   162  	c.Assert(err, gc.ErrorMatches, `unable to connect to API: .*`)
   163  	c.Assert(atomic.LoadInt32(&count), gc.Equals, int32(3))
   164  }
   165  
   166  func (s *apiclientSuite) TestVerifyCA(c *gc.C) {
   167  	decodedCACert, _ := pem.Decode([]byte(jtesting.CACert))
   168  	serverCertWithoutCA, _ := tls.X509KeyPair([]byte(jtesting.ServerCert), []byte(jtesting.ServerKey))
   169  	serverCertWithSelfSignedCA, _ := tls.X509KeyPair([]byte(jtesting.ServerCert), []byte(jtesting.ServerKey))
   170  	serverCertWithSelfSignedCA.Certificate = append(serverCertWithSelfSignedCA.Certificate, decodedCACert.Bytes)
   171  
   172  	specs := []struct {
   173  		descr        string
   174  		serverCert   tls.Certificate
   175  		verifyCA     func(host, endpoint string, caCert *x509.Certificate) error
   176  		expConnCount int32
   177  		errRegex     string
   178  	}{
   179  		{
   180  			descr:      "VerifyCA provided but server does not present a CA cert",
   181  			serverCert: serverCertWithoutCA,
   182  			verifyCA: func(host, endpoint string, caCert *x509.Certificate) error {
   183  				return errors.New("VerifyCA should not be called")
   184  			},
   185  			// Dial tries to fetch CAs, doesn't find any and
   186  			// proceeds with the connection to the servers. This
   187  			// would be the case where we connect to an older juju
   188  			// controller.
   189  			expConnCount: 2,
   190  			errRegex:     `unable to connect to API: .*`,
   191  		},
   192  		{
   193  			descr:      "no VerifyCA provided",
   194  			serverCert: serverCertWithSelfSignedCA,
   195  			// Dial connects to all servers
   196  			expConnCount: 1,
   197  			errRegex:     `unable to connect to API: .*`,
   198  		},
   199  		{
   200  			descr:      "VerifyCA that always rejects certs",
   201  			serverCert: serverCertWithSelfSignedCA,
   202  			verifyCA: func(host, endpoint string, caCert *x509.Certificate) error {
   203  				return errors.New("CA not trusted")
   204  			},
   205  			// Dial aborts after fetching CAs
   206  			expConnCount: 1,
   207  			errRegex:     "CA not trusted",
   208  		},
   209  		{
   210  			descr:      "VerifyCA that always accepts certs",
   211  			serverCert: serverCertWithSelfSignedCA,
   212  			verifyCA: func(host, endpoint string, caCert *x509.Certificate) error {
   213  				return nil
   214  			},
   215  			// Dial fetches CAs and then proceeds with the connection to the servers
   216  			expConnCount: 2,
   217  			errRegex:     `unable to connect to API: .*`,
   218  		},
   219  	}
   220  
   221  	info := s.APIInfo(c)
   222  	for specIndex, spec := range specs {
   223  		c.Logf("test %d: %s", specIndex, spec.descr)
   224  
   225  		// connCount holds the number of times we've accepted a connection.
   226  		var connCount int32
   227  		tlsConf := &tls.Config{
   228  			Certificates: []tls.Certificate{spec.serverCert},
   229  		}
   230  
   231  		listener, err := tls.Listen("tcp", "127.0.0.1:0", tlsConf)
   232  		c.Assert(err, jc.ErrorIsNil)
   233  		defer listener.Close()
   234  		go func() {
   235  			buf := make([]byte, 4)
   236  			for {
   237  				client, err := listener.Accept()
   238  				if err != nil {
   239  					return
   240  				}
   241  				atomic.AddInt32(&connCount, 1)
   242  
   243  				// Do a dummy read to prevent the connection from
   244  				// closing before the client can access the certs.
   245  				_, _ = client.Read(buf)
   246  				_ = client.Close()
   247  			}
   248  		}()
   249  
   250  		atomic.StoreInt32(&connCount, 0)
   251  		info.Addrs = []string{listener.Addr().String()}
   252  		_, _, err = api.DialAPI(info, api.DialOpts{
   253  			VerifyCA: spec.verifyCA,
   254  		})
   255  		c.Assert(err, gc.ErrorMatches, spec.errRegex)
   256  		c.Assert(atomic.LoadInt32(&connCount), gc.Equals, spec.expConnCount)
   257  	}
   258  }
   259  
   260  func (s *apiclientSuite) TestOpen(c *gc.C) {
   261  	info := s.APIInfo(c)
   262  	st, err := api.Open(info, api.DialOpts{})
   263  	c.Assert(err, jc.ErrorIsNil)
   264  	defer st.Close()
   265  
   266  	c.Assert(st.Addr(), gc.Equals, info.Addrs[0])
   267  	modelTag, ok := st.ModelTag()
   268  	c.Assert(ok, jc.IsTrue)
   269  	c.Assert(modelTag, gc.Equals, s.Model.ModelTag())
   270  
   271  	remoteVersion, versionSet := st.ServerVersion()
   272  	c.Assert(versionSet, jc.IsTrue)
   273  	c.Assert(remoteVersion, gc.Equals, jujuversion.Current)
   274  
   275  	c.Assert(api.CookieURL(st).String(), gc.Equals, "https://deadbeef-1bad-500d-9000-4b1d0d06f00d/")
   276  }
   277  
   278  func (s *apiclientSuite) TestOpenCookieURLUsesSNIHost(c *gc.C) {
   279  	info := s.APIInfo(c)
   280  	info.SNIHostName = "somehost"
   281  	st, err := api.Open(info, api.DialOpts{})
   282  	c.Assert(err, jc.ErrorIsNil)
   283  	defer st.Close()
   284  
   285  	c.Assert(api.CookieURL(st).String(), gc.Equals, "https://somehost/")
   286  }
   287  
   288  func (s *apiclientSuite) TestOpenCookieURLDefaultsToAddress(c *gc.C) {
   289  	info := s.APIInfo(c)
   290  	info.ControllerUUID = ""
   291  	st, err := api.Open(info, api.DialOpts{})
   292  	c.Assert(err, jc.ErrorIsNil)
   293  	defer st.Close()
   294  
   295  	c.Assert(api.CookieURL(st).String(), gc.Matches, "https://localhost:.*/")
   296  }
   297  
   298  func (s *apiclientSuite) TestOpenHonorsModelTag(c *gc.C) {
   299  	info := s.APIInfo(c)
   300  
   301  	// TODO(jam): 2014-06-05 http://pad.lv/1326802
   302  	// we want to test this eventually, but for now s.APIInfo uses
   303  	// conn.StateInfo() which doesn't know about ModelTag.
   304  	// c.Check(info.ModelTag, gc.Equals, model.Tag())
   305  	// c.Assert(info.ModelTag, gc.Not(gc.Equals), "")
   306  
   307  	// We start by ensuring we have an invalid tag, and Open should fail.
   308  	info.ModelTag = names.NewModelTag("0b501e7e-cafe-f00d-ba1d-b1a570c0e199")
   309  	_, err := api.Open(info, api.DialOpts{})
   310  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
   311  		Message: `unknown model: "0b501e7e-cafe-f00d-ba1d-b1a570c0e199"`,
   312  		Code:    "model not found",
   313  	})
   314  	c.Check(params.ErrCode(err), gc.Equals, params.CodeModelNotFound)
   315  
   316  	// Now set it to the right tag, and we should succeed.
   317  	info.ModelTag = s.Model.ModelTag()
   318  	st, err := api.Open(info, api.DialOpts{})
   319  	c.Assert(err, jc.ErrorIsNil)
   320  	st.Close()
   321  
   322  	// Backwards compatibility, we should succeed if we do not set an
   323  	// model tag
   324  	info.ModelTag = names.NewModelTag("")
   325  	st, err = api.Open(info, api.DialOpts{})
   326  	c.Assert(err, jc.ErrorIsNil)
   327  	st.Close()
   328  }
   329  
   330  func (s *apiclientSuite) TestServerRoot(c *gc.C) {
   331  	url := api.ServerRoot(s.APIState)
   332  	c.Assert(url, gc.Matches, "https://localhost:[0-9]+")
   333  }
   334  
// TestDialWebsocketStopsOtherDialAttempts checks that once one of two
// concurrent dial attempts succeeds, api.Open returns without waiting
// for the other attempt, that the other attempt's context is done, and
// that a connection the other attempt returns afterwards is closed.
//
// The test drives api.Open through a fake dialer that reports each dial
// on a channel and blocks until the test sends it a reply, and through a
// test clock that the test advances by hand.
func (s *apiclientSuite) TestDialWebsocketStopsOtherDialAttempts(c *gc.C) {
	// Try to open the API with two addresses.
	// Wait for connection attempts to both.
	// Let one succeed.
	// Wait for the other to be canceled.

	// dialResponse is what the test sends back to a blocked fake dial.
	type dialResponse struct {
		conn jsoncodec.JSONConn
	}
	// dialInfo describes one call to the fake dialer. Note that this
	// local type shadows the package-level dialInfo within this function.
	type dialInfo struct {
		ctx      context.Context
		location string
		replyc   chan<- dialResponse
	}
	dialed := make(chan dialInfo)
	// fakeDialer announces the dial on the dialed channel, then blocks
	// until the test replies with the connection to return.
	fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) {
		reply := make(chan dialResponse)
		dialed <- dialInfo{
			ctx:      ctx,
			location: urlStr,
			replyc:   reply,
		}
		r := <-reply
		return r.conn, nil
	}
	conn0 := fakeConn{}
	// Note: this local shadows the imported clock package for the rest
	// of the function.
	clock := testclock.NewClock(time.Now())
	openDone := make(chan struct{})
	const dialAddressInterval = 50 * time.Millisecond
	// Run Open in the background; openDone is closed when it returns.
	go func() {
		defer close(openDone)
		conn, err := api.Open(&api.Info{
			Addrs: []string{
				"place1.example:1234",
				"place2.example:1234",
			},
			SkipLogin: true,
			CACert:    jtesting.CACert,
		}, api.DialOpts{
			Timeout:             5 * time.Second,
			RetryDelay:          1 * time.Second,
			DialAddressInterval: dialAddressInterval,
			DialWebsocket:       fakeDialer,
			Clock:               clock,
			IPAddrResolver: apitesting.IPAddrResolverMap{
				"place1.example": {"0.1.1.1"},
				"place2.example": {"0.2.2.2"},
			},
		})
		// NOTE(review): the connection is inspected before err is
		// checked; confirm api.UnderlyingConn tolerates a nil conn.
		c.Check(api.UnderlyingConn(conn), gc.Equals, conn0)
		c.Check(err, jc.ErrorIsNil)
	}()

	place1 := "wss://place1.example:1234/api"
	place2 := "wss://place2.example:1234/api"
	// Wait for first connection, but don't
	// reply immediately because we want
	// to wait for the second connection before
	// letting the first one succeed.
	var info0 dialInfo
	select {
	case info0 = <-dialed:
	case <-time.After(jtesting.LongWait):
		c.Fatalf("timed out waiting for dial")
	}
	this := place1
	other := place2
	if info0.location != place1 {
		// We now randomly order what we will connect to. So we check
		// whether we first tried to connect to place1 or place2.
		// However, we should still be able to interrupt a second dial by
		// having the first one succeed.
		this = place2
		other = place1
	}

	// Fails if the first dial went to neither of the two places.
	c.Assert(info0.location, gc.Equals, this)

	var info1 dialInfo
	// Wait for the next dial to be made. Note that we wait for two
	// waiters because ContextWithTimeout as created by the
	// outer level of api.Open also waits.
	err := clock.WaitAdvance(dialAddressInterval, time.Second, 2)
	c.Assert(err, jc.ErrorIsNil)

	select {
	case info1 = <-dialed:
	case <-time.After(jtesting.LongWait):
		c.Fatalf("timed out waiting for dial")
	}
	c.Assert(info1.location, gc.Equals, other)

	// Allow the first dial to succeed.
	info0.replyc <- dialResponse{
		conn: conn0,
	}

	// The Open returns immediately without waiting
	// for the second dial to complete.
	select {
	case <-openDone:
	case <-time.After(jtesting.LongWait):
		c.Fatalf("timed out waiting for connection")
	}

	// The second dial's context is canceled to tell
	// it to stop.
	select {
	case <-info1.ctx.Done():
	case <-time.After(jtesting.LongWait):
		c.Fatalf("timed out waiting for context to be closed")
	}
	// conn1 carries a closed channel so the test can wait on it below.
	conn1 := fakeConn{
		closed: make(chan struct{}),
	}
	// Allow the second dial to succeed.
	info1.replyc <- dialResponse{
		conn: conn1,
	}
	// Check that the connection it returns is closed.
	select {
	case <-conn1.closed:
	case <-time.After(jtesting.LongWait):
		c.Fatalf("timed out waiting for connection to be closed")
	}
}
   461  
// apiDialInfo holds the dial parameters a test case in
// openWithSNIHostnameTests expects the dialer to be called with.
type apiDialInfo struct {
	// location is the expected URL passed to the dialer.
	location string
	// hasRootCAs reports whether the TLS config is expected to have a
	// non-nil RootCAs pool.
	hasRootCAs bool
	// serverName is the expected TLS config ServerName.
	serverName string
}
   467  
// openWithSNIHostnameTests is the table used by TestOpenWithSNIHostname.
// Each entry pairs the api.Info handed to api.Open with the location and
// TLS settings the dialer is expected to receive. Per the expectations
// below, entries without a CA cert expect the SNI host name as server
// name and no root CAs, while entries with a CA cert expect root CAs to
// be set and "juju-apiserver" as server name.
var openWithSNIHostnameTests = []struct {
	// about is logged before the case runs.
	about string
	// info is the connection info passed to api.Open.
	info *api.Info
	// expectDial is what the dialer is expected to be called with.
	expectDial apiDialInfo
}{{
	about: "no cert; DNS name - use SNI hostname",
	info: &api.Info{
		Addrs:       []string{"foo.com:1234"},
		SNIHostName: "foo.com",
		SkipLogin:   true,
	},
	expectDial: apiDialInfo{
		location:   "wss://foo.com:1234/api",
		hasRootCAs: false,
		serverName: "foo.com",
	},
}, {
	about: "no cert; numeric IP address - use SNI hostname",
	info: &api.Info{
		Addrs:       []string{"0.1.2.3:1234"},
		SNIHostName: "foo.com",
		SkipLogin:   true,
	},
	expectDial: apiDialInfo{
		location:   "wss://0.1.2.3:1234/api",
		hasRootCAs: false,
		serverName: "foo.com",
	},
}, {
	// NOTE(review): the description says "DNS name" but the address
	// below is a numeric IP — confirm whether that is intended.
	about: "with cert; DNS name - use cert",
	info: &api.Info{
		Addrs:       []string{"0.1.1.1:1234"},
		SNIHostName: "foo.com",
		SkipLogin:   true,
		CACert:      jtesting.CACert,
	},
	expectDial: apiDialInfo{
		location:   "wss://0.1.1.1:1234/api",
		hasRootCAs: true,
		serverName: "juju-apiserver",
	},
}, {
	about: "with cert; numeric IP address - use cert",
	info: &api.Info{
		Addrs:       []string{"0.1.2.3:1234"},
		SNIHostName: "foo.com",
		SkipLogin:   true,
		CACert:      jtesting.CACert,
	},
	expectDial: apiDialInfo{
		location:   "wss://0.1.2.3:1234/api",
		hasRootCAs: true,
		serverName: "juju-apiserver",
	},
}}
   523  
   524  func (s *apiclientSuite) TestOpenWithSNIHostname(c *gc.C) {
   525  	for i, test := range openWithSNIHostnameTests {
   526  		c.Logf("test %d: %v", i, test.about)
   527  		s.testOpenDialError(c, dialTest{
   528  			apiInfo:         test.info,
   529  			expectOpenError: `unable to connect to API: nope`,
   530  			expectDials: []dialAttempt{{
   531  				check: func(info dialInfo) {
   532  					c.Check(info.location, gc.Equals, test.expectDial.location)
   533  					c.Assert(info.tlsConfig, gc.NotNil)
   534  					c.Check(info.tlsConfig.RootCAs != nil, gc.Equals, test.expectDial.hasRootCAs)
   535  					c.Check(info.tlsConfig.ServerName, gc.Equals, test.expectDial.serverName)
   536  				},
   537  				returnError: errors.New("nope"),
   538  			}},
   539  			allowMoreDials: true,
   540  		})
   541  	}
   542  }
   543  
   544  func (s *apiclientSuite) TestFallbackToSNIHostnameOnCertErrorAndNonNumericHostname(c *gc.C) {
   545  	s.testOpenDialError(c, dialTest{
   546  		apiInfo: &api.Info{
   547  			Addrs:       []string{"x.com:1234"},
   548  			CACert:      jtesting.CACert,
   549  			SNIHostName: "foo.com",
   550  		},
   551  		// go 1.9 says "is not authorized to sign for this name"
   552  		// go 1.10 says "is not authorized to sign for this domain"
   553  		expectOpenError: `unable to connect to API: x509: a root or intermediate certificate is not authorized to sign.*`,
   554  		expectDials: []dialAttempt{{
   555  			// The first dial attempt should use the private CA cert.
   556  			check: func(info dialInfo) {
   557  				c.Assert(info.tlsConfig, gc.NotNil)
   558  				c.Check(info.tlsConfig.RootCAs.Subjects(), gc.HasLen, 1)
   559  				c.Check(info.tlsConfig.ServerName, gc.Equals, "juju-apiserver")
   560  			},
   561  			returnError: x509.CertificateInvalidError{
   562  				Reason: x509.CANotAuthorizedForThisName,
   563  			},
   564  		}, {
   565  			// The second dial attempt should fall back to using the
   566  			// SNI hostname.
   567  			check: func(info dialInfo) {
   568  				c.Assert(info.tlsConfig, gc.NotNil)
   569  				c.Check(info.tlsConfig.RootCAs, gc.IsNil)
   570  				c.Check(info.tlsConfig.ServerName, gc.Equals, "foo.com")
   571  			},
   572  			// Note: we return another certificate error so that
   573  			// the Open logic returns immediately rather than waiting
   574  			// for the timeout.
   575  			returnError: x509.SystemRootsError{},
   576  		}},
   577  	})
   578  }
   579  
   580  func (s *apiclientSuite) TestFailImmediatelyOnCertErrorAndNumericHostname(c *gc.C) {
   581  	s.testOpenDialError(c, dialTest{
   582  		apiInfo: &api.Info{
   583  			Addrs:  []string{"0.1.2.3:1234"},
   584  			CACert: jtesting.CACert,
   585  		},
   586  		// go 1.9 says "is not authorized to sign for this name"
   587  		// go 1.10 says "is not authorized to sign for this domain"
   588  		expectOpenError: `unable to connect to API: x509: a root or intermediate certificate is not authorized to sign.*`,
   589  		expectDials: []dialAttempt{{
   590  			// The first dial attempt should use the private CA cert.
   591  			check: func(info dialInfo) {
   592  				c.Assert(info.tlsConfig, gc.NotNil)
   593  				c.Check(info.tlsConfig.RootCAs.Subjects(), gc.HasLen, 1)
   594  				c.Check(info.tlsConfig.ServerName, gc.Equals, "juju-apiserver")
   595  			},
   596  			returnError: x509.CertificateInvalidError{
   597  				Reason: x509.CANotAuthorizedForThisName,
   598  			},
   599  		}},
   600  	})
   601  }
   602  
// dialTest describes one scenario for testOpenDialError: the connection
// info to open, the dial attempts expected along the way, and the error
// api.Open is expected to return.
type dialTest struct {
	// apiInfo is the connection info passed to api.Open.
	apiInfo *api.Info
	// expectDials holds an entry for each dial
	// attempt that's expected to be made.
	// If allowMoreDials is true, any number of
	// attempts will be allowed and the last entry
	// of expectDials will be used when the
	// number of attempts exceeds len(expectDials).
	expectDials    []dialAttempt
	allowMoreDials bool
	// expectOpenError is a regular expression matched against the
	// error returned by api.Open.
	expectOpenError string
}
   615  
// dialAttempt describes how testOpenDialError handles one expected dial.
type dialAttempt struct {
	// check is called with the details of the dial so that the test
	// can verify them.
	check func(info dialInfo)
	// returnError is the error the fake dialer returns for this dial.
	returnError error
}
   620  
// dialInfo records one call made to the fake dialer installed by
// testOpenDialError.
type dialInfo struct {
	// location is the URL the dialer was asked to connect to.
	location string
	// tlsConfig is the TLS configuration passed to the dialer.
	tlsConfig *tls.Config
	// errc receives the error that the blocked dial should return.
	errc chan<- error
}
   626  
   627  func (s *apiclientSuite) testOpenDialError(c *gc.C, t dialTest) {
   628  	dialed := make(chan dialInfo)
   629  	fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) {
   630  		reply := make(chan error)
   631  		dialed <- dialInfo{
   632  			location:  urlStr,
   633  			tlsConfig: tlsConfig,
   634  			errc:      reply,
   635  		}
   636  		return nil, <-reply
   637  	}
   638  	done := make(chan struct{})
   639  	go func() {
   640  		defer close(done)
   641  		conn, err := api.Open(t.apiInfo, api.DialOpts{
   642  			DialWebsocket:  fakeDialer,
   643  			IPAddrResolver: seqResolver(t.apiInfo.Addrs...),
   644  			Clock:          &fakeClock{},
   645  		})
   646  		c.Check(conn, gc.Equals, nil)
   647  		c.Check(err, gc.ErrorMatches, t.expectOpenError)
   648  	}()
   649  	for i := 0; t.allowMoreDials || i < len(t.expectDials); i++ {
   650  		c.Logf("attempt %d", i)
   651  		var attempt dialAttempt
   652  		if i < len(t.expectDials) {
   653  			attempt = t.expectDials[i]
   654  		} else if t.allowMoreDials {
   655  			attempt = t.expectDials[len(t.expectDials)-1]
   656  		} else {
   657  			break
   658  		}
   659  		select {
   660  		case info := <-dialed:
   661  			attempt.check(info)
   662  			info.errc <- attempt.returnError
   663  		case <-done:
   664  			if i < len(t.expectDials) {
   665  				c.Fatalf("Open returned early - expected dials not made")
   666  			}
   667  			return
   668  		case <-time.After(jtesting.LongWait):
   669  			c.Fatalf("timed out waiting for dial")
   670  		}
   671  	}
   672  	select {
   673  	case <-done:
   674  	case <-time.After(jtesting.LongWait):
   675  		c.Fatalf("timed out waiting for API open")
   676  	}
   677  }
   678  
   679  func (s *apiclientSuite) TestOpenWithNoCACert(c *gc.C) {
   680  	// This is hard to test as we have no way of affecting the system roots,
   681  	// so instead we check that the error that we get implies that
   682  	// we're using the system roots.
   683  
   684  	info := s.APIInfo(c)
   685  	info.CACert = ""
   686  
   687  	// This test used to use a long timeout so that we can check that the retry
   688  	// logic doesn't retry, but that got all messed up with dualstack IPs.
   689  	// The api server was only listening on IPv4, but localhost resolved to both
   690  	// IPv4 and IPv6. The IPv4 didn't retry, but the IPv6 one did, because it was
   691  	// retrying the dial. The parallel try doesn't have a fatal error type yet.
   692  	_, err := api.Open(info, api.DialOpts{
   693  		Timeout:    2 * time.Second,
   694  		RetryDelay: 200 * time.Millisecond,
   695  	})
   696  	c.Assert(err, gc.ErrorMatches, `unable to connect to API:.*x509: certificate signed by unknown authority`)
   697  }
   698  
   699  func (s *apiclientSuite) TestOpenWithRedirect(c *gc.C) {
   700  	redirectToHosts := []string{"0.1.2.3:1234", "0.1.2.4:1235"}
   701  	redirectToCACert := "fake CA cert"
   702  
   703  	srv := apiservertesting.NewAPIServer(func(modelUUID string) interface{} {
   704  		return &redirectAPI{
   705  			modelUUID:        modelUUID,
   706  			redirectToHosts:  redirectToHosts,
   707  			redirectToCACert: redirectToCACert,
   708  		}
   709  	})
   710  	defer srv.Close()
   711  
   712  	_, err := api.Open(&api.Info{
   713  		Addrs:    srv.Addrs,
   714  		CACert:   jtesting.CACert,
   715  		ModelTag: names.NewModelTag("beef1beef1-0000-0000-000011112222"),
   716  	}, api.DialOpts{})
   717  	c.Assert(err, gc.ErrorMatches, `redirection to alternative server required`)
   718  
   719  	hps := make(network.MachineHostPorts, len(redirectToHosts))
   720  	for i, addr := range redirectToHosts {
   721  		hp, err := network.ParseMachineHostPort(addr)
   722  		c.Assert(err, jc.ErrorIsNil)
   723  		hps[i] = *hp
   724  	}
   725  
   726  	c.Assert(errors.Cause(err), jc.DeepEquals, &api.RedirectError{
   727  		Servers:        []network.MachineHostPorts{hps},
   728  		CACert:         redirectToCACert,
   729  		FollowRedirect: true,
   730  	})
   731  }
   732  
   733  func (s *apiclientSuite) TestOpenCachesDNS(c *gc.C) {
   734  	fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) {
   735  		return fakeConn{}, nil
   736  	}
   737  	dnsCache := make(dnsCacheMap)
   738  	conn, err := api.Open(&api.Info{
   739  		Addrs: []string{
   740  			"place1.example:1234",
   741  		},
   742  		SkipLogin: true,
   743  		CACert:    jtesting.CACert,
   744  	}, api.DialOpts{
   745  		DialWebsocket: fakeDialer,
   746  		IPAddrResolver: apitesting.IPAddrResolverMap{
   747  			"place1.example": {"0.1.1.1"},
   748  		},
   749  		DNSCache: dnsCache,
   750  	})
   751  	c.Assert(err, jc.ErrorIsNil)
   752  	c.Assert(conn, gc.NotNil)
   753  	c.Assert(dnsCache.Lookup("place1.example"), jc.DeepEquals, []string{"0.1.1.1"})
   754  }
   755  
   756  func (s *apiclientSuite) TestDNSCacheUsed(c *gc.C) {
   757  	var dialed string
   758  	fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) {
   759  		dialed = ipAddr
   760  		return fakeConn{}, nil
   761  	}
   762  	conn, err := api.Open(&api.Info{
   763  		Addrs: []string{
   764  			"place1.example:1234",
   765  		},
   766  		SkipLogin: true,
   767  		CACert:    jtesting.CACert,
   768  	}, api.DialOpts{
   769  		DialWebsocket: fakeDialer,
   770  		// Note: don't resolve any addresses. If we resolve one,
   771  		// then there's a possibility that the resolving will
   772  		// happen and a second dial attempt will happen before
   773  		// the Open returns, giving rise to a race.
   774  		IPAddrResolver: apitesting.IPAddrResolverMap{},
   775  		DNSCache: dnsCacheMap{
   776  			"place1.example": {"0.1.1.1"},
   777  		},
   778  	})
   779  	c.Assert(err, jc.ErrorIsNil)
   780  	c.Assert(conn, gc.NotNil)
   781  	// The dialed IP address should have come from the cache, not the IP address
   782  	// resolver.
   783  	c.Assert(dialed, gc.Equals, "0.1.1.1:1234")
   784  	c.Assert(conn.IPAddr(), gc.Equals, "0.1.1.1:1234")
   785  }
   786  
   787  func (s *apiclientSuite) TestNumericAddressIsNotAddedToCache(c *gc.C) {
   788  	fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) {
   789  		return fakeConn{}, nil
   790  	}
   791  	dnsCache := make(dnsCacheMap)
   792  	conn, err := api.Open(&api.Info{
   793  		Addrs: []string{
   794  			"0.1.2.3:1234",
   795  		},
   796  		SkipLogin: true,
   797  		CACert:    jtesting.CACert,
   798  	}, api.DialOpts{
   799  		DialWebsocket:  fakeDialer,
   800  		IPAddrResolver: apitesting.IPAddrResolverMap{},
   801  		DNSCache:       dnsCache,
   802  	})
   803  	c.Assert(err, jc.ErrorIsNil)
   804  	c.Assert(conn, gc.NotNil)
   805  	c.Assert(conn.Addr(), gc.Equals, "0.1.2.3:1234")
   806  	c.Assert(conn.IPAddr(), gc.Equals, "0.1.2.3:1234")
   807  	c.Assert(dnsCache, gc.HasLen, 0)
   808  }
   809  
// TestFallbackToIPLookupWhenCacheOutOfDate checks that, when the DNS
// cache holds an address that cannot be dialed, Open also dials the
// address returned by the resolver, connects through it, and leaves the
// cache holding the resolver's address.
func (s *apiclientSuite) TestFallbackToIPLookupWhenCacheOutOfDate(c *gc.C) {
	// dialc reports the IP address of each dial attempt; start is
	// closed to let the blocked attempts return.
	dialc := make(chan string)
	start := make(chan struct{})
	// fakeDialer succeeds only for 0.2.2.2:1234, the address the
	// resolver below returns; the cached address fails.
	fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) {
		dialc <- ipAddr
		<-start
		if ipAddr == "0.2.2.2:1234" {
			return fakeConn{}, nil
		}
		return nil, errors.Errorf("bad address")
	}
	// The cache starts with an address that differs from the
	// resolver's answer.
	dnsCache := dnsCacheMap{
		"place1.example": {"0.1.1.1"},
	}
	type openResult struct {
		conn api.Connection
		err  error
	}
	openc := make(chan openResult)
	// Run Open in the background and deliver its result on openc.
	go func() {
		conn, err := api.Open(&api.Info{
			Addrs: []string{
				"place1.example:1234",
			},
			SkipLogin: true,
			CACert:    jtesting.CACert,
		}, api.DialOpts{
			// Note: zero timeout means each address
			// will be tried only once.
			DialWebsocket: fakeDialer,
			IPAddrResolver: apitesting.IPAddrResolverMap{
				"place1.example": {"0.2.2.2"},
			},
			DNSCache: dnsCache,
		})
		openc <- openResult{conn, err}
	}()
	// Wait for both dial attempts to happen.
	// If we don't, then the second attempt might
	// happen before the first one and the first
	// attempt might then never happen.
	dialed := make(map[string]bool)
	for i := 0; i < 2; i++ {
		select {
		case hostPort := <-dialc:
			dialed[hostPort] = true
		case <-time.After(jtesting.LongWait):
			c.Fatalf("timed out waiting for dial attempt")
		}
	}
	// Allow the dial attempts to return.
	close(start)
	// Check that no more dial attempts happen.
	select {
	case hostPort := <-dialc:
		c.Fatalf("unexpected dial attempt to %q; existing attempts: %v", hostPort, dialed)
	case <-time.After(jtesting.ShortWait):
	}
	// NOTE(review): this receive has no timeout, so the test blocks
	// here if Open never returns.
	r := <-openc
	c.Assert(r.err, jc.ErrorIsNil)
	c.Assert(r.conn, gc.NotNil)
	c.Assert(r.conn.Addr(), gc.Equals, "place1.example:1234")
	c.Assert(r.conn.IPAddr(), gc.Equals, "0.2.2.2:1234")
	c.Assert(dialed, jc.DeepEquals, map[string]bool{
		"0.2.2.2:1234": true,
		"0.1.1.1:1234": true,
	})
	c.Assert(dnsCache.Lookup("place1.example"), jc.DeepEquals, []string{"0.2.2.2"})
}
   879  
   880  func (s *apiclientSuite) TestOpenTimesOutOnLogin(c *gc.C) {
   881  	unblock := make(chan chan struct{})
   882  	srv := apiservertesting.NewAPIServer(func(modelUUID string) interface{} {
   883  		return &loginTimeoutAPI{
   884  			unblock: unblock,
   885  		}
   886  	})
   887  	defer srv.Close()
   888  	defer close(unblock)
   889  
   890  	clk := testclock.NewClock(time.Now())
   891  	done := make(chan error, 1)
   892  	go func() {
   893  		_, err := api.Open(&api.Info{
   894  			Addrs:    srv.Addrs,
   895  			CACert:   jtesting.CACert,
   896  			ModelTag: names.NewModelTag("beef1beef1-0000-0000-000011112222"),
   897  		}, api.DialOpts{
   898  			Clock:   clk,
   899  			Timeout: 5 * time.Second,
   900  		})
   901  		done <- err
   902  	}()
   903  	// Wait for Login to be entered before we advance the clock. Note that we don't actually unblock the request,
   904  	// we just ensure that the other side has gotten to the point where it wants to be blocked. Otherwise we might
   905  	// advance the clock before we even get the api.Dial to finish or before TLS handshaking finishes.
   906  	unblocked := make(chan struct{})
   907  	defer close(unblocked)
   908  	select {
   909  	case unblock <- unblocked:
   910  	case <-time.After(jtesting.LongWait):
   911  		c.Fatalf("timed out waiting for Login to be called")
   912  	}
   913  	err := clk.WaitAdvance(5*time.Second, time.Second, 1)
   914  	c.Assert(err, jc.ErrorIsNil)
   915  	select {
   916  	case err := <-done:
   917  		c.Assert(err, gc.ErrorMatches, `cannot log in: context deadline exceeded`)
   918  	case <-time.After(time.Second):
   919  		c.Fatalf("timed out waiting for api.Open timeout")
   920  	}
   921  }
   922  
   923  func (s *apiclientSuite) TestOpenTimeoutAffectsDial(c *gc.C) {
   924  	sync := make(chan struct{})
   925  	fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) {
   926  		close(sync)
   927  		<-ctx.Done()
   928  		return nil, ctx.Err()
   929  	}
   930  
   931  	clk := testclock.NewClock(time.Now())
   932  	done := make(chan error, 1)
   933  	go func() {
   934  		_, err := api.Open(&api.Info{
   935  			Addrs:     []string{"127.0.0.1:1234"},
   936  			CACert:    jtesting.CACert,
   937  			ModelTag:  names.NewModelTag("beef1beef1-0000-0000-000011112222"),
   938  			SkipLogin: true,
   939  		}, api.DialOpts{
   940  			Clock:         clk,
   941  			Timeout:       5 * time.Second,
   942  			DialWebsocket: fakeDialer,
   943  		})
   944  		done <- err
   945  	}()
   946  	// Before we advance time, ensure that the parallel try mechanism
   947  	// has entered the dial function.
   948  	select {
   949  	case <-sync:
   950  	case <-time.After(testing.LongWait):
   951  		c.Errorf("didn't enter dial")
   952  	}
   953  	err := clk.WaitAdvance(5*time.Second, time.Second, 1)
   954  	c.Assert(err, jc.ErrorIsNil)
   955  	select {
   956  	case err := <-done:
   957  		c.Assert(err, gc.ErrorMatches, `unable to connect to API: context deadline exceeded`)
   958  	case <-time.After(time.Second):
   959  		c.Fatalf("timed out waiting for api.Open timeout")
   960  	}
   961  }
   962  
   963  func (s *apiclientSuite) TestOpenDialTimeoutAffectsDial(c *gc.C) {
   964  	sync := make(chan struct{})
   965  	fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) {
   966  		close(sync)
   967  		<-ctx.Done()
   968  		return nil, ctx.Err()
   969  	}
   970  
   971  	clk := testclock.NewClock(time.Now())
   972  	done := make(chan error, 1)
   973  	go func() {
   974  		_, err := api.Open(&api.Info{
   975  			Addrs:     []string{"127.0.0.1:1234"},
   976  			CACert:    jtesting.CACert,
   977  			ModelTag:  names.NewModelTag("beef1beef1-0000-0000-000011112222"),
   978  			SkipLogin: true,
   979  		}, api.DialOpts{
   980  			Clock:         clk,
   981  			Timeout:       5 * time.Second,
   982  			DialTimeout:   3 * time.Second,
   983  			DialWebsocket: fakeDialer,
   984  		})
   985  		done <- err
   986  	}()
   987  	// Before we advance time, ensure that the parallel try mechanism
   988  	// has entered the dial function.
   989  	select {
   990  	case <-sync:
   991  	case <-time.After(testing.LongWait):
   992  		c.Errorf("didn't enter dial")
   993  	}
   994  	err := clk.WaitAdvance(3*time.Second, time.Second, 2) // Timeout & DialTimeout
   995  	c.Assert(err, jc.ErrorIsNil)
   996  	select {
   997  	case err := <-done:
   998  		c.Assert(err, gc.ErrorMatches, `unable to connect to API: context deadline exceeded`)
   999  	case <-time.After(time.Second):
  1000  		c.Fatalf("timed out waiting for api.Open timeout")
  1001  	}
  1002  }
  1003  
  1004  func (s *apiclientSuite) TestOpenDialTimeoutDoesNotAffectLogin(c *gc.C) {
  1005  	unblock := make(chan chan struct{})
  1006  	srv := apiservertesting.NewAPIServer(func(modelUUID string) interface{} {
  1007  		return &loginTimeoutAPI{
  1008  			unblock: unblock,
  1009  		}
  1010  	})
  1011  	defer srv.Close()
  1012  	defer close(unblock)
  1013  
  1014  	clk := testclock.NewClock(time.Now())
  1015  	done := make(chan error, 1)
  1016  	go func() {
  1017  		_, err := api.Open(&api.Info{
  1018  			Addrs:    srv.Addrs,
  1019  			CACert:   jtesting.CACert,
  1020  			ModelTag: names.NewModelTag("beef1beef1-0000-0000-000011112222"),
  1021  		}, api.DialOpts{
  1022  			Clock:       clk,
  1023  			DialTimeout: 5 * time.Second,
  1024  		})
  1025  		done <- err
  1026  	}()
  1027  
  1028  	// We should not get a response from api.Open until we
  1029  	// unblock the login.
  1030  	unblocked := make(chan struct{})
  1031  	select {
  1032  	case unblock <- unblocked:
  1033  		// We are now in the Login method of the loginTimeoutAPI.
  1034  	case <-time.After(jtesting.LongWait):
  1035  		c.Fatalf("didn't enter Login")
  1036  	}
  1037  
  1038  	// There should be nothing waiting. Advance the clock to where it
  1039  	// would have triggered the DialTimeout. But this doesn't stop api.Open
  1040  	// as we have already connected and entered Login.
  1041  	err := clk.WaitAdvance(5*time.Second, 0, 0)
  1042  	c.Assert(err, jc.ErrorIsNil)
  1043  
  1044  	// Ensure that api.Open doesn't return until we tell it to.
  1045  	select {
  1046  	case <-done:
  1047  		c.Fatalf("unexpected return from api.Open")
  1048  	case <-time.After(jtesting.ShortWait):
  1049  	}
  1050  
  1051  	// unblock the login by sending to "unblocked", and then the
  1052  	// api.Open should return the result of the login.
  1053  	close(unblocked)
  1054  	select {
  1055  	case err := <-done:
  1056  		c.Assert(err, gc.ErrorMatches, "login failed")
  1057  	case <-time.After(jtesting.LongWait):
  1058  		c.Fatalf("timed out waiting for api.Open to return")
  1059  	}
  1060  }
  1061  
  1062  func (s *apiclientSuite) TestWithUnresolvableAddr(c *gc.C) {
  1063  	fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) {
  1064  		c.Errorf("dial was called but should not have been")
  1065  		return nil, errors.Errorf("cannot dial")
  1066  	}
  1067  	conn, err := api.Open(&api.Info{
  1068  		Addrs: []string{
  1069  			"nowhere.example:1234",
  1070  		},
  1071  		SkipLogin: true,
  1072  		CACert:    jtesting.CACert,
  1073  	}, api.DialOpts{
  1074  		DialWebsocket:  fakeDialer,
  1075  		IPAddrResolver: apitesting.IPAddrResolverMap{},
  1076  	})
  1077  	c.Assert(err, gc.ErrorMatches, `cannot resolve "nowhere.example": mock resolver cannot resolve "nowhere.example"`)
  1078  	c.Assert(conn, jc.ErrorIsNil)
  1079  }
  1080  
  1081  func (s *apiclientSuite) TestWithUnresolvableAddrAfterCacheFallback(c *gc.C) {
  1082  	var dialedReal bool
  1083  	fakeDialer := func(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) {
  1084  		if ipAddr == "0.2.2.2:1234" {
  1085  			dialedReal = true
  1086  			return nil, errors.Errorf("cannot connect with real address")
  1087  		}
  1088  		return nil, errors.Errorf("bad address from cache")
  1089  	}
  1090  	dnsCache := dnsCacheMap{
  1091  		"place1.example": {"0.1.1.1"},
  1092  	}
  1093  	conn, err := api.Open(&api.Info{
  1094  		Addrs: []string{
  1095  			"place1.example:1234",
  1096  		},
  1097  		SkipLogin: true,
  1098  		CACert:    jtesting.CACert,
  1099  	}, api.DialOpts{
  1100  		DialWebsocket: fakeDialer,
  1101  		IPAddrResolver: apitesting.IPAddrResolverMap{
  1102  			"place1.example": {"0.2.2.2"},
  1103  		},
  1104  		DNSCache: dnsCache,
  1105  	})
  1106  	c.Assert(err, gc.NotNil)
  1107  	c.Assert(conn, gc.Equals, nil)
  1108  	c.Assert(dnsCache.Lookup("place1.example"), jc.DeepEquals, []string{"0.2.2.2"})
  1109  	c.Assert(dialedReal, jc.IsTrue)
  1110  }
  1111  
  1112  func (s *apiclientSuite) TestAPICallNoError(c *gc.C) {
  1113  	clock := &fakeClock{}
  1114  	conn := api.NewTestingState(api.TestingStateParams{
  1115  		RPCConnection: newRPCConnection(),
  1116  		Clock:         clock,
  1117  	})
  1118  
  1119  	err := conn.APICall("facade", 1, "id", "method", nil, nil)
  1120  	c.Check(err, jc.ErrorIsNil)
  1121  	c.Check(clock.waits, gc.HasLen, 0)
  1122  }
  1123  
  1124  func (s *apiclientSuite) TestAPICallError(c *gc.C) {
  1125  	clock := &fakeClock{}
  1126  	conn := api.NewTestingState(api.TestingStateParams{
  1127  		RPCConnection: newRPCConnection(errors.BadRequestf("boom")),
  1128  		Clock:         clock,
  1129  	})
  1130  
  1131  	err := conn.APICall("facade", 1, "id", "method", nil, nil)
  1132  	c.Check(err.Error(), gc.Equals, "boom")
  1133  	c.Check(err, jc.Satisfies, errors.IsBadRequest)
  1134  	c.Check(clock.waits, gc.HasLen, 0)
  1135  }
  1136  
  1137  func (s *apiclientSuite) TestIsBrokenOk(c *gc.C) {
  1138  	conn := api.NewTestingState(api.TestingStateParams{
  1139  		RPCConnection: newRPCConnection(),
  1140  		Clock:         new(fakeClock),
  1141  	})
  1142  	c.Assert(conn.IsBroken(), jc.IsFalse)
  1143  }
  1144  
  1145  func (s *apiclientSuite) TestIsBrokenChannelClosed(c *gc.C) {
  1146  	broken := make(chan struct{})
  1147  	close(broken)
  1148  	conn := api.NewTestingState(api.TestingStateParams{
  1149  		RPCConnection: newRPCConnection(),
  1150  		Clock:         new(fakeClock),
  1151  		Broken:        broken,
  1152  	})
  1153  	c.Assert(conn.IsBroken(), jc.IsTrue)
  1154  }
  1155  
  1156  func (s *apiclientSuite) TestIsBrokenPingFailed(c *gc.C) {
  1157  	conn := api.NewTestingState(api.TestingStateParams{
  1158  		RPCConnection: newRPCConnection(errors.New("no biscuit")),
  1159  		Clock:         new(fakeClock),
  1160  	})
  1161  	c.Assert(conn.IsBroken(), jc.IsTrue)
  1162  }
  1163  
  1164  func (s *apiclientSuite) TestLoginCapturesCLIArgs(c *gc.C) {
  1165  	s.PatchValue(&os.Args, []string{"this", "is", "the test", "command"})
  1166  
  1167  	info := s.APIInfo(c)
  1168  	conn := newRPCConnection()
  1169  	conn.response = &params.LoginResult{
  1170  		ControllerTag: "controller-" + s.ControllerConfig.ControllerUUID(),
  1171  		ServerVersion: "2.3-rc2",
  1172  	}
  1173  	// Pass an already-closed channel so we don't wait for the monitor
  1174  	// to signal the rpc connection is dead when closing the state
  1175  	// (because there's no monitor running).
  1176  	broken := make(chan struct{})
  1177  	close(broken)
  1178  	testState := api.NewTestingState(api.TestingStateParams{
  1179  		RPCConnection: conn,
  1180  		Clock:         &fakeClock{},
  1181  		Address:       "localhost:1234",
  1182  		Broken:        broken,
  1183  		Closed:        make(chan struct{}),
  1184  	})
  1185  	err := testState.Login(info.Tag, info.Password, "", nil)
  1186  	c.Assert(err, jc.ErrorIsNil)
  1187  
  1188  	calls := conn.stub.Calls()
  1189  	c.Assert(calls, gc.HasLen, 1)
  1190  	call := calls[0]
  1191  	c.Assert(call.FuncName, gc.Equals, "Admin.Login")
  1192  	c.Assert(call.Args, gc.HasLen, 2)
  1193  	request := call.Args[1].(*params.LoginRequest)
  1194  	c.Assert(request.CLIArgs, gc.Equals, `this is "the test" command`)
  1195  }
  1196  
  1197  func (s *apiclientSuite) TestConnectStreamRequiresSlashPathPrefix(c *gc.C) {
  1198  	reader, err := s.APIState.ConnectStream("foo", nil)
  1199  	c.Assert(err, gc.ErrorMatches, `cannot make API path from non-slash-prefixed path "foo"`)
  1200  	c.Assert(reader, gc.Equals, nil)
  1201  }
  1202  
  1203  func (s *apiclientSuite) TestConnectStreamErrorBadConnection(c *gc.C) {
  1204  	s.PatchValue(&api.WebsocketDial, func(_ api.WebsocketDialer, _ string, _ http.Header) (base.Stream, error) {
  1205  		return nil, fmt.Errorf("bad connection")
  1206  	})
  1207  	reader, err := s.APIState.ConnectStream("/", nil)
  1208  	c.Assert(err, gc.ErrorMatches, "bad connection")
  1209  	c.Assert(reader, gc.IsNil)
  1210  }
  1211  
  1212  func (s *apiclientSuite) TestConnectStreamErrorNoData(c *gc.C) {
  1213  	s.PatchValue(&api.WebsocketDial, func(_ api.WebsocketDialer, _ string, _ http.Header) (base.Stream, error) {
  1214  		return api.NewFakeStreamReader(&bytes.Buffer{}), nil
  1215  	})
  1216  	reader, err := s.APIState.ConnectStream("/", nil)
  1217  	c.Assert(err, gc.ErrorMatches, "unable to read initial response: EOF")
  1218  	c.Assert(reader, gc.IsNil)
  1219  }
  1220  
  1221  func (s *apiclientSuite) TestConnectStreamErrorBadData(c *gc.C) {
  1222  	s.PatchValue(&api.WebsocketDial, func(_ api.WebsocketDialer, _ string, _ http.Header) (base.Stream, error) {
  1223  		return api.NewFakeStreamReader(strings.NewReader("junk\n")), nil
  1224  	})
  1225  	reader, err := s.APIState.ConnectStream("/", nil)
  1226  	c.Assert(err, gc.ErrorMatches, "unable to unmarshal initial response: .*")
  1227  	c.Assert(reader, gc.IsNil)
  1228  }
  1229  
  1230  func (s *apiclientSuite) TestConnectStreamErrorReadError(c *gc.C) {
  1231  	s.PatchValue(&api.WebsocketDial, func(_ api.WebsocketDialer, _ string, _ http.Header) (base.Stream, error) {
  1232  		err := fmt.Errorf("bad read")
  1233  		return api.NewFakeStreamReader(&badReader{err}), nil
  1234  	})
  1235  	reader, err := s.APIState.ConnectStream("/", nil)
  1236  	c.Assert(err, gc.ErrorMatches, "unable to read initial response: bad read")
  1237  	c.Assert(reader, gc.IsNil)
  1238  }
  1239  
  1240  // badReader raises err when Read is called.
  1241  type badReader struct {
  1242  	err error
  1243  }
  1244  
  1245  func (r *badReader) Read(p []byte) (n int, err error) {
  1246  	return 0, r.err
  1247  }
  1248  
  1249  func (s *apiclientSuite) TestConnectControllerStreamRejectsRelativePaths(c *gc.C) {
  1250  	reader, err := s.APIState.ConnectControllerStream("foo", nil, nil)
  1251  	c.Assert(err, gc.ErrorMatches, `path "foo" is not absolute`)
  1252  	c.Assert(reader, gc.IsNil)
  1253  }
  1254  
  1255  func (s *apiclientSuite) TestConnectControllerStreamRejectsModelPaths(c *gc.C) {
  1256  	reader, err := s.APIState.ConnectControllerStream("/model/foo", nil, nil)
  1257  	c.Assert(err, gc.ErrorMatches, `path "/model/foo" is model-specific`)
  1258  	c.Assert(reader, gc.IsNil)
  1259  }
  1260  
  1261  func (s *apiclientSuite) TestConnectControllerStreamAppliesHeaders(c *gc.C) {
  1262  	catcher := api.UrlCatcher{}
  1263  	headers := http.Header{}
  1264  	headers.Add("thomas", "cromwell")
  1265  	headers.Add("anne", "boleyn")
  1266  	s.PatchValue(&api.WebsocketDial, catcher.RecordLocation)
  1267  
  1268  	_, err := s.APIState.ConnectControllerStream("/something", nil, headers)
  1269  	c.Assert(err, jc.ErrorIsNil)
  1270  	c.Assert(catcher.Headers().Get("thomas"), gc.Equals, "cromwell")
  1271  	c.Assert(catcher.Headers().Get("anne"), gc.Equals, "boleyn")
  1272  }
  1273  
  1274  func (s *apiclientSuite) TestWatchDebugLogParamsEncoded(c *gc.C) {
  1275  	catcher := api.UrlCatcher{}
  1276  	s.PatchValue(&api.WebsocketDial, catcher.RecordLocation)
  1277  
  1278  	params := common.DebugLogParams{
  1279  		IncludeEntity: []string{"a", "b"},
  1280  		IncludeModule: []string{"c", "d"},
  1281  		IncludeLabel:  []string{"e", "f"},
  1282  		ExcludeEntity: []string{"g", "h"},
  1283  		ExcludeModule: []string{"i", "j"},
  1284  		ExcludeLabel:  []string{"k", "l"},
  1285  		Limit:         100,
  1286  		Backlog:       200,
  1287  		Level:         loggo.ERROR,
  1288  		Replay:        true,
  1289  		NoTail:        true,
  1290  		StartTime:     time.Date(2016, 11, 30, 11, 48, 0, 100, time.UTC),
  1291  	}
  1292  
  1293  	urlValues := url.Values{
  1294  		"includeEntity": params.IncludeEntity,
  1295  		"includeModule": params.IncludeModule,
  1296  		"includeLabel":  params.IncludeLabel,
  1297  		"excludeEntity": params.ExcludeEntity,
  1298  		"excludeModule": params.ExcludeModule,
  1299  		"excludeLabel":  params.ExcludeLabel,
  1300  		"maxLines":      {"100"},
  1301  		"backlog":       {"200"},
  1302  		"level":         {"ERROR"},
  1303  		"replay":        {"true"},
  1304  		"noTail":        {"true"},
  1305  		"startTime":     {"2016-11-30T11:48:00.0000001Z"},
  1306  	}
  1307  
  1308  	client := apiclient.NewClient(s.APIState, jtesting.NoopLogger{})
  1309  	_, err := client.WatchDebugLog(params)
  1310  	c.Assert(err, jc.ErrorIsNil)
  1311  
  1312  	connectURL, err := url.Parse(catcher.Location())
  1313  	c.Assert(err, jc.ErrorIsNil)
  1314  
  1315  	values := connectURL.Query()
  1316  	c.Assert(values, jc.DeepEquals, urlValues)
  1317  }
  1318  
  1319  func (s *apiclientSuite) TestWatchDebugLogConnected(c *gc.C) {
  1320  	cl := apiclient.NewClient(s.APIState, jtesting.NoopLogger{})
  1321  	// Use the no tail option so we don't try to start a tailing cursor
  1322  	// on the oplog when there is no oplog configured in mongo as the tests
  1323  	// don't set up mongo in replicaset mode.
  1324  	messages, err := cl.WatchDebugLog(common.DebugLogParams{NoTail: true})
  1325  	c.Assert(err, jc.ErrorIsNil)
  1326  	c.Assert(messages, gc.NotNil)
  1327  }
  1328  
  1329  func (s *apiclientSuite) TestConnectStreamAtUUIDPath(c *gc.C) {
  1330  	catcher := api.UrlCatcher{}
  1331  	s.PatchValue(&api.WebsocketDial, catcher.RecordLocation)
  1332  	model, err := s.State.Model()
  1333  	c.Assert(err, jc.ErrorIsNil)
  1334  	info := s.APIInfo(c)
  1335  	info.ModelTag = model.ModelTag()
  1336  	apistate, err := api.Open(info, api.DialOpts{})
  1337  	c.Assert(err, jc.ErrorIsNil)
  1338  	defer apistate.Close()
  1339  	_, err = apistate.ConnectStream("/path", nil)
  1340  	c.Assert(err, jc.ErrorIsNil)
  1341  	connectURL, err := url.Parse(catcher.Location())
  1342  	c.Assert(err, jc.ErrorIsNil)
  1343  	c.Assert(connectURL.Path, gc.Matches, fmt.Sprintf("/model/%s/path", model.UUID()))
  1344  }
  1345  
  1346  func (s *apiclientSuite) TestOpenUsesModelUUIDPaths(c *gc.C) {
  1347  	info := s.APIInfo(c)
  1348  
  1349  	// Passing in the correct model UUID should work
  1350  	model, err := s.State.Model()
  1351  	c.Assert(err, jc.ErrorIsNil)
  1352  	info.ModelTag = model.ModelTag()
  1353  	apistate, err := api.Open(info, api.DialOpts{})
  1354  	c.Assert(err, jc.ErrorIsNil)
  1355  	apistate.Close()
  1356  
  1357  	// Passing in an unknown model UUID should fail with a known error
  1358  	info.ModelTag = names.NewModelTag("1eaf1e55-70ad-face-b007-70ad57001999")
  1359  	apistate, err = api.Open(info, api.DialOpts{})
  1360  	c.Assert(errors.Cause(err), gc.DeepEquals, &rpc.RequestError{
  1361  		Message: `unknown model: "1eaf1e55-70ad-face-b007-70ad57001999"`,
  1362  		Code:    "model not found",
  1363  	})
  1364  	c.Check(err, jc.Satisfies, params.IsCodeModelNotFound)
  1365  	c.Assert(apistate, gc.IsNil)
  1366  }
  1367  
// clientDNSNameSuite holds API client tests that run against a controller
// configured with an autocert DNS name (set up in SetUpTest).
type clientDNSNameSuite struct {
	jjtesting.JujuConnSuite
}

// Register the suite with gocheck.
var _ = gc.Suite(&clientDNSNameSuite{})
  1373  
  1374  func (s *clientDNSNameSuite) SetUpTest(c *gc.C) {
  1375  	// Start an API server with a (non-working) autocert hostname,
  1376  	// so we can check that the PublicDNSName in the result goes
  1377  	// all the way through the layers.
  1378  	if s.ControllerConfigAttrs == nil {
  1379  		s.ControllerConfigAttrs = make(map[string]interface{})
  1380  	}
  1381  	s.ControllerConfigAttrs[controller.AutocertDNSNameKey] = "somewhere.example.com"
  1382  	s.JujuConnSuite.SetUpTest(c)
  1383  }
  1384  
  1385  func (s *clientDNSNameSuite) TestPublicDNSName(c *gc.C) {
  1386  	apiInfo := s.APIInfo(c)
  1387  	conn, err := api.Open(apiInfo, api.DialOpts{})
  1388  	c.Assert(err, gc.IsNil)
  1389  	c.Assert(conn.PublicDNSName(), gc.Equals, "somewhere.example.com")
  1390  }
  1391  
  1392  type fakeClock struct {
  1393  	clock.Clock
  1394  
  1395  	mu    sync.Mutex
  1396  	now   time.Time
  1397  	waits []time.Duration
  1398  }
  1399  
  1400  func (f *fakeClock) Now() time.Time {
  1401  	f.mu.Lock()
  1402  	defer f.mu.Unlock()
  1403  	if f.now.IsZero() {
  1404  		f.now = time.Now()
  1405  	}
  1406  	return f.now
  1407  }
  1408  
  1409  func (f *fakeClock) After(d time.Duration) <-chan time.Time {
  1410  	f.mu.Lock()
  1411  	defer f.mu.Unlock()
  1412  	f.waits = append(f.waits, d)
  1413  	f.now = f.now.Add(d)
  1414  	return time.After(0)
  1415  }
  1416  
  1417  func (f *fakeClock) NewTimer(d time.Duration) clock.Timer {
  1418  	panic("NewTimer called on fakeClock - perhaps because fakeClock can't be used with DialOpts.Timeout")
  1419  }
  1420  
  1421  func newRPCConnection(errs ...error) *fakeRPCConnection {
  1422  	conn := new(fakeRPCConnection)
  1423  	conn.stub.SetErrors(errs...)
  1424  	return conn
  1425  }
  1426  
  1427  type fakeRPCConnection struct {
  1428  	stub     testing.Stub
  1429  	response interface{}
  1430  }
  1431  
  1432  func (f *fakeRPCConnection) Dead() <-chan struct{} {
  1433  	return nil
  1434  }
  1435  
  1436  func (f *fakeRPCConnection) Close() error {
  1437  	return nil
  1438  }
  1439  
  1440  func (f *fakeRPCConnection) Call(req rpc.Request, params, response interface{}) error {
  1441  	f.stub.AddCall(req.Type+"."+req.Action, req.Version, params)
  1442  	if f.response != nil {
  1443  		rv := reflect.ValueOf(response)
  1444  		target := reflect.Indirect(rv)
  1445  		target.Set(reflect.Indirect(reflect.ValueOf(f.response)))
  1446  	}
  1447  	return f.stub.NextErr()
  1448  }
  1449  
  1450  type redirectAPI struct {
  1451  	redirected       bool
  1452  	modelUUID        string
  1453  	redirectToHosts  []string
  1454  	redirectToCACert string
  1455  }
  1456  
  1457  func (r *redirectAPI) Admin(id string) (*redirectAPIAdmin, error) {
  1458  	return &redirectAPIAdmin{r}, nil
  1459  }
  1460  
  1461  type redirectAPIAdmin struct {
  1462  	r *redirectAPI
  1463  }
  1464  
  1465  func (a *redirectAPIAdmin) Login(req params.LoginRequest) (params.LoginResult, error) {
  1466  	if a.r.modelUUID != "beef1beef1-0000-0000-000011112222" {
  1467  		return params.LoginResult{}, errors.New("logged into unexpected model")
  1468  	}
  1469  	a.r.redirected = true
  1470  	return params.LoginResult{}, params.Error{
  1471  		Message: "redirect",
  1472  		Code:    params.CodeRedirect,
  1473  	}
  1474  }
  1475  
  1476  func (a *redirectAPIAdmin) RedirectInfo() (params.RedirectInfoResult, error) {
  1477  	if !a.r.redirected {
  1478  		return params.RedirectInfoResult{}, errors.New("not redirected")
  1479  	}
  1480  
  1481  	hps, err := network.ParseProviderHostPorts(a.r.redirectToHosts...)
  1482  	if err != nil {
  1483  		panic(err)
  1484  	}
  1485  	return params.RedirectInfoResult{
  1486  		Servers: [][]params.HostPort{params.FromProviderHostPorts(hps)},
  1487  		CACert:  a.r.redirectToCACert,
  1488  	}, nil
  1489  }
  1490  
  1491  func assertConnAddrForModel(c *gc.C, location, addr, modelUUID string) {
  1492  	c.Assert(location, gc.Equals, "wss://"+addr+"/model/"+modelUUID+"/api")
  1493  }
  1494  
  1495  func assertConnAddrForRoot(c *gc.C, location, addr string) {
  1496  	c.Assert(location, gc.Matches, "wss://"+addr+"/api")
  1497  }
  1498  
  1499  type fakeConn struct {
  1500  	closed chan struct{}
  1501  }
  1502  
  1503  func (c fakeConn) Receive(x interface{}) error {
  1504  	return errors.New("no data available from fake connection")
  1505  }
  1506  
  1507  func (c fakeConn) Send(x interface{}) error {
  1508  	return errors.New("cannot write to fake connection")
  1509  }
  1510  
  1511  func (c fakeConn) Close() error {
  1512  	if c.closed != nil {
  1513  		close(c.closed)
  1514  	}
  1515  	return nil
  1516  }
  1517  
  1518  // seqResolver returns an implementation of
  1519  // IPAddrResolver that maps the given addresses
  1520  // to sequential IP addresses 0.1.1.1, 0.2.2.2, etc.
  1521  func seqResolver(addrs ...string) api.IPAddrResolver {
  1522  	r := make(apitesting.IPAddrResolverMap)
  1523  	for i, addr := range addrs {
  1524  		host, _, err := net.SplitHostPort(addr)
  1525  		if err != nil {
  1526  			panic(err)
  1527  		}
  1528  		r[host] = []string{fmt.Sprintf("0.%[1]d.%[1]d.%[1]d", i+1)}
  1529  	}
  1530  	return r
  1531  }
  1532  
  1533  type dnsCacheMap map[string][]string
  1534  
  1535  func (m dnsCacheMap) Lookup(host string) []string {
  1536  	return m[host]
  1537  }
  1538  
  1539  func (m dnsCacheMap) Add(host string, ips []string) {
  1540  	m[host] = append([]string{}, ips...)
  1541  }
  1542  
  1543  type loginTimeoutAPI struct {
  1544  	unblock chan chan struct{}
  1545  }
  1546  
  1547  func (r *loginTimeoutAPI) Admin(id string) (*loginTimeoutAPIAdmin, error) {
  1548  	return &loginTimeoutAPIAdmin{r}, nil
  1549  }
  1550  
  1551  type loginTimeoutAPIAdmin struct {
  1552  	r *loginTimeoutAPI
  1553  }
  1554  
  1555  func (a *loginTimeoutAPIAdmin) Login(req params.LoginRequest) (params.LoginResult, error) {
  1556  	var unblocked chan struct{}
  1557  	select {
  1558  	case ch, ok := <-a.r.unblock:
  1559  		if !ok {
  1560  			return params.LoginResult{}, errors.New("abort")
  1561  		}
  1562  		unblocked = ch
  1563  	case <-time.After(jtesting.LongWait):
  1564  		return params.LoginResult{}, errors.New("timed out waiting to be unblocked")
  1565  	}
  1566  	select {
  1567  	case <-unblocked:
  1568  	case <-time.After(jtesting.LongWait):
  1569  		return params.LoginResult{}, errors.New("timed out sending on unblocked channel")
  1570  	}
  1571  	return params.LoginResult{}, errors.Errorf("login failed")
  1572  }