// Source: github.com/juju/juju@v0.0.0-20240430160146-1752b71fcf00/api/apiclient.go

// Copyright 2012-2015 Canonical Ltd.
// Licensed under the AGPLv3, see LICENCE file for details.

package api

import (
	"bufio"
	"context"
	"crypto/tls"
	"crypto/x509"
	"encoding/json"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync/atomic"
	"time"

	"github.com/go-macaroon-bakery/macaroon-bakery/v3/bakery"
	"github.com/go-macaroon-bakery/macaroon-bakery/v3/httpbakery"
	"github.com/gorilla/websocket"
	"github.com/juju/clock"
	"github.com/juju/errors"
	jujuhttp "github.com/juju/http/v2"
	"github.com/juju/loggo"
	"github.com/juju/names/v5"
	"github.com/juju/utils/v3"
	"github.com/juju/utils/v3/parallel"
	"gopkg.in/retry.v1"

	"github.com/juju/juju/api/base"
	"github.com/juju/juju/core/facades"
	coremacaroon "github.com/juju/juju/core/macaroon"
	"github.com/juju/juju/core/network"
	jujuproxy "github.com/juju/juju/proxy"
	"github.com/juju/juju/rpc"
	"github.com/juju/juju/rpc/jsoncodec"
	"github.com/juju/juju/rpc/params"
	"github.com/juju/juju/utils/proxy"
	jujuversion "github.com/juju/juju/version"
)

// PingPeriod defines how often the internal connection health check
// will run.
const PingPeriod = 1 * time.Minute

// pingTimeout defines how long a health check can take before we
// consider it to have failed.
const pingTimeout = 30 * time.Second

// modelRoot is the prefix that all model API paths begin with.
const modelRoot = "/model/"

// logger is the package-level logger, registered under "juju.api".
var logger = loggo.GetLogger("juju.api")

// rpcConnection lists the RPC connection operations this file relies on.
type rpcConnection interface {
	// Call issues the given request with params and decodes the reply
	// into response.
	Call(req rpc.Request, params, response interface{}) error

	// Dead returns a channel used to observe connection termination.
	// NOTE(review): presumably closed once the connection has died -
	// confirm against the rpc package.
	Dead() <-chan struct{}

	// Close closes the connection.
	Close() error
}

// RedirectError is returned from Open when the controller
// needs to inform the client that the model is hosted
// on a different set of API addresses.
69 type RedirectError struct { 70 // Servers holds the sets of addresses of the redirected 71 // servers. 72 Servers []network.MachineHostPorts 73 74 // CACert holds the certificate of the remote server. 75 CACert string 76 77 // FollowRedirect is set to true for cases like JAAS where the client 78 // needs to automatically follow the redirect to the new controller. 79 FollowRedirect bool 80 81 // ControllerTag uniquely identifies the controller being redirected to. 82 ControllerTag names.ControllerTag 83 84 // An optional alias for the controller the model got redirected to. 85 // It can be used by the client to present the user with a more 86 // meaningful juju login -c XYZ command 87 ControllerAlias string 88 } 89 90 func (e *RedirectError) Error() string { 91 return "redirection to alternative server required" 92 } 93 94 // Open establishes a connection to the API server using the Info 95 // given, returning a State instance which can be used to make API 96 // requests. 97 // 98 // If the model is hosted on a different server, Open 99 // will return an error with a *RedirectError cause 100 // holding the details of another server to connect to. 101 // 102 // See Connect for details of the connection mechanics. 
103 func Open(info *Info, opts DialOpts) (Connection, error) { 104 if err := info.Validate(); err != nil { 105 return nil, errors.Annotate(err, "validating info for opening an API connection") 106 } 107 if opts.Clock == nil { 108 opts.Clock = clock.WallClock 109 } 110 ctx := context.Background() 111 dialCtx := ctx 112 if opts.Timeout > 0 { 113 ctx1, cancel := utils.ContextWithTimeout(dialCtx, opts.Clock, opts.Timeout) 114 defer cancel() 115 dialCtx = ctx1 116 } 117 118 dialResult, err := dialAPI(dialCtx, info, opts) 119 if err != nil { 120 return nil, errors.Trace(err) 121 } 122 123 client := rpc.NewConn(jsoncodec.New(dialResult.conn), nil) 124 client.Start(ctx) 125 126 bakeryClient := opts.BakeryClient 127 if bakeryClient == nil { 128 bakeryClient = httpbakery.NewClient() 129 } else { 130 // Make a copy of the bakery client and its HTTP client 131 c := *opts.BakeryClient 132 bakeryClient = &c 133 httpc := *bakeryClient.Client 134 bakeryClient.Client = &httpc 135 } 136 137 // Technically when there's no CACert, we don't need this 138 // machinery, because we could just use http.DefaultTransport 139 // for everything, but it's easier just to leave it in place. 140 bakeryClient.Client.Transport = &hostSwitchingTransport{ 141 primaryHost: dialResult.addr, 142 primary: jujuhttp.NewHTTPTLSTransport(jujuhttp.TransportConfig{ 143 TLSConfig: dialResult.tlsConfig, 144 }), 145 fallback: http.DefaultTransport, 146 } 147 148 host := PerferredHost(info) 149 if host == "" { 150 host = dialResult.addr 151 } 152 153 pingerFacadeVersions := facadeVersions["Pinger"] 154 if len(pingerFacadeVersions) == 0 { 155 return nil, errors.Errorf("pinger facade version is required") 156 } 157 158 loginProvider := opts.LoginProvider 159 // TODO (alesstimec, wallyworld): login provider should be constructed outside 160 // of this function and always passed in as part of dial opts. Also Info 161 // does not need to hold the authentication related data anymore. 
Until that 162 // is refactored we fall back to using the user-pass login provider 163 // with information from Info. 164 if loginProvider == nil { 165 loginProvider = NewUserpassLoginProvider(info.Tag, info.Password, info.Nonce, info.Macaroons, bakeryClient, CookieURLFromHost(host)) 166 } 167 168 st := &state{ 169 ctx: context.Background(), 170 client: client, 171 conn: dialResult.conn, 172 clock: opts.Clock, 173 addr: dialResult.addr, 174 ipAddr: dialResult.ipAddr, 175 cookieURL: CookieURLFromHost(host), 176 pingerFacadeVersion: pingerFacadeVersions[len(pingerFacadeVersions)-1], 177 serverScheme: "https", 178 serverRootAddress: dialResult.addr, 179 // We populate the username and password before 180 // login because, when doing HTTP requests, we'll want 181 // to use the same username and password for authenticating 182 // those. If login fails, we discard the connection. 183 tag: tagToString(info.Tag), 184 password: info.Password, 185 macaroons: info.Macaroons, 186 nonce: info.Nonce, 187 tlsConfig: dialResult.tlsConfig, 188 bakeryClient: bakeryClient, 189 modelTag: info.ModelTag, 190 proxier: dialResult.proxier, 191 } 192 if !info.SkipLogin { 193 if err := loginWithContext(dialCtx, st, loginProvider); err != nil { 194 dialResult.conn.Close() 195 return nil, errors.Trace(err) 196 } 197 } 198 199 st.broken = make(chan struct{}) 200 st.closed = make(chan struct{}) 201 202 go (&monitor{ 203 clock: opts.Clock, 204 ping: st.ping, 205 pingPeriod: PingPeriod, 206 pingTimeout: pingTimeout, 207 closed: st.closed, 208 dead: client.Dead(), 209 broken: st.broken, 210 }).run() 211 return st, nil 212 } 213 214 // CookieURLFromHost creates a url.URL from a given host. 215 func CookieURLFromHost(host string) *url.URL { 216 return &url.URL{ 217 Scheme: "https", 218 Host: host, 219 Path: "/", 220 } 221 } 222 223 // PerferredHost returns the SNI hostname or controller name for the cookie URL 224 // so that it is stable when used with a HA controller cluster. 
225 func PerferredHost(info *Info) string { 226 if info == nil { 227 return "" 228 } 229 230 host := info.SNIHostName 231 if host == "" && info.ControllerUUID != "" { 232 host = info.ControllerUUID 233 } 234 return host 235 } 236 237 // loginWithContext wraps st.Login with code that terminates 238 // if the context is cancelled. 239 // TODO(rogpeppe) pass Context into Login (and all API calls) so 240 // that this becomes unnecessary. 241 func loginWithContext(ctx context.Context, st *state, loginProvider LoginProvider) error { 242 if loginProvider == nil { 243 return errors.New("login provider not specified") 244 } 245 246 result := make(chan error, 1) 247 go func() { 248 loginResult, err := loginProvider.Login(ctx, st) 249 if err != nil { 250 result <- err 251 return 252 } 253 254 result <- st.setLoginResult(loginResult) 255 }() 256 select { 257 case err := <-result: 258 return errors.Trace(err) 259 case <-ctx.Done(): 260 return errors.Annotatef(ctx.Err(), "cannot log in") 261 } 262 } 263 264 // hostSwitchingTransport provides an http.RoundTripper 265 // that chooses an actual RoundTripper to use 266 // depending on the destination host. 267 // 268 // This makes it possible to use a different set of root 269 // CAs for the API and all other hosts. 270 type hostSwitchingTransport struct { 271 primaryHost string 272 primary http.RoundTripper 273 fallback http.RoundTripper 274 } 275 276 // RoundTrip implements http.RoundTripper.RoundTrip. 277 func (t *hostSwitchingTransport) RoundTrip(req *http.Request) (*http.Response, error) { 278 if req.URL.Host == t.primaryHost { 279 return t.primary.RoundTrip(req) 280 } 281 return t.fallback.RoundTrip(req) 282 } 283 284 // Context returns the context associated with this state. 285 func (st *state) Context() context.Context { 286 return st.ctx 287 } 288 289 // ConnectStream implements StreamConnector.ConnectStream. 
The stream 290 // returned will apply a 30-second write deadline, so WriteJSON should 291 // only be called from one goroutine. 292 func (st *state) ConnectStream(path string, attrs url.Values) (base.Stream, error) { 293 path, err := apiPath(st.modelTag.Id(), path) 294 if err != nil { 295 return nil, errors.Trace(err) 296 } 297 conn, err := st.connectStreamWithRetry(path, attrs, nil) 298 if err != nil { 299 return nil, errors.Trace(err) 300 } 301 return conn, nil 302 } 303 304 // ConnectControllerStream creates a stream connection to an API path 305 // that isn't prefixed with /model/uuid - the target model (if the 306 // endpoint needs one) can be specified in the headers. The stream 307 // returned will apply a 30-second write deadline, so WriteJSON should 308 // only be called from one goroutine. 309 func (st *state) ConnectControllerStream(path string, attrs url.Values, headers http.Header) (base.Stream, error) { 310 if !strings.HasPrefix(path, "/") { 311 return nil, errors.Errorf("path %q is not absolute", path) 312 } 313 if strings.HasPrefix(path, modelRoot) { 314 return nil, errors.Errorf("path %q is model-specific", path) 315 } 316 conn, err := st.connectStreamWithRetry(path, attrs, headers) 317 if err != nil { 318 return nil, errors.Trace(err) 319 } 320 return conn, nil 321 } 322 323 func (st *state) connectStreamWithRetry(path string, attrs url.Values, headers http.Header) (base.Stream, error) { 324 if !st.isLoggedIn() { 325 return nil, errors.New("cannot use ConnectStream without logging in") 326 } 327 // We use the standard "macaraq" macaroon authentication dance here. 328 // That is, we attach any macaroons we have to the initial request, 329 // and if that succeeds, all's good. If it fails with a DischargeRequired 330 // error, the response will contain a macaroon that, when discharged, 331 // may allow access, so we discharge it (using bakery.Client.HandleError) 332 // and try the request again. 
333 conn, err := st.connectStream(path, attrs, headers) 334 if err == nil { 335 return conn, err 336 } 337 if params.ErrCode(err) != params.CodeDischargeRequired { 338 return nil, errors.Trace(err) 339 } 340 if err := st.bakeryClient.HandleError(st.ctx, st.cookieURL, bakeryError(err)); err != nil { 341 return nil, errors.Trace(err) 342 } 343 // Try again with the discharged macaroon. 344 conn, err = st.connectStream(path, attrs, headers) 345 if err != nil { 346 return nil, errors.Trace(err) 347 } 348 return conn, nil 349 } 350 351 // connectStream is the internal version of ConnectStream. It differs from 352 // ConnectStream only in that it will not retry the connection if it encounters 353 // discharge-required error. 354 func (st *state) connectStream(path string, attrs url.Values, extraHeaders http.Header) (base.Stream, error) { 355 target := url.URL{ 356 Scheme: "wss", 357 Host: st.addr, 358 Path: path, 359 RawQuery: attrs.Encode(), 360 } 361 // TODO(macgreagoir) IPv6. Ubuntu still always provides IPv4 loopback, 362 // and when/if this changes localhost should resolve to IPv6 loopback 363 // in any case (lp:1644009). Review. 364 365 dialer := &websocket.Dialer{ 366 Proxy: proxy.DefaultConfig.GetProxy, 367 TLSClientConfig: st.tlsConfig, 368 } 369 var requestHeader http.Header 370 if st.tag != "" { 371 requestHeader = jujuhttp.BasicAuthHeader(st.tag, st.password) 372 } else { 373 requestHeader = make(http.Header) 374 } 375 requestHeader.Set(params.JujuClientVersion, jujuversion.Current.String()) 376 requestHeader.Set("Origin", "http://localhost/") 377 if st.nonce != "" { 378 requestHeader.Set(params.MachineNonceHeader, st.nonce) 379 } 380 // Add any cookies because they will not be sent to websocket 381 // connections by default. 
382 err := st.addCookiesToHeader(requestHeader) 383 if err != nil { 384 return nil, errors.Trace(err) 385 } 386 for header, values := range extraHeaders { 387 for _, value := range values { 388 requestHeader.Add(header, value) 389 } 390 } 391 392 connection, err := WebsocketDial(dialer, target.String(), requestHeader) 393 if err != nil { 394 return nil, err 395 } 396 if err := readInitialStreamError(connection); err != nil { 397 connection.Close() 398 return nil, errors.Trace(err) 399 } 400 return connection, nil 401 } 402 403 // readInitialStreamError reads the initial error response 404 // from a stream connection and returns it. 405 func readInitialStreamError(ws base.Stream) error { 406 // We can use bufio here because the websocket guarantees that a 407 // single read will not read more than a single frame; there is 408 // no guarantee that a single read might not read less than the 409 // whole frame though, so using a single Read call is not 410 // correct. By using ReadSlice rather than ReadBytes, we 411 // guarantee that the error can't be too big (>4096 bytes). 412 messageType, reader, err := ws.NextReader() 413 if err != nil { 414 return errors.Annotate(err, "unable to get reader") 415 } 416 if messageType != websocket.TextMessage { 417 return errors.Errorf("unexpected message type %v", messageType) 418 } 419 line, err := bufio.NewReader(reader).ReadSlice('\n') 420 if err != nil { 421 return errors.Annotate(err, "unable to read initial response") 422 } 423 var errResult params.ErrorResult 424 if err := json.Unmarshal(line, &errResult); err != nil { 425 return errors.Annotate(err, "unable to unmarshal initial response") 426 } 427 if errResult.Error != nil { 428 return errResult.Error 429 } 430 return nil 431 } 432 433 // addCookiesToHeader adds any cookies associated with the 434 // API host to the given header. This is necessary because 435 // otherwise cookies are not sent to websocket endpoints. 
436 func (st *state) addCookiesToHeader(h http.Header) error { 437 // net/http only allows adding cookies to a request, 438 // but when it sends a request to a non-http endpoint, 439 // it doesn't add the cookies, so make a request, starting 440 // with the given header, add the cookies to use, then 441 // throw away the request but keep the header. 442 req := &http.Request{ 443 Header: h, 444 } 445 cookies := st.bakeryClient.Client.Jar.Cookies(st.cookieURL) 446 for _, c := range cookies { 447 req.AddCookie(c) 448 } 449 if len(cookies) == 0 && len(st.macaroons) > 0 { 450 // These macaroons must have been added directly rather than 451 // obtained from a request. Add them. (For example in the 452 // logtransfer connection for a migration.) 453 // See https://bugs.launchpad.net/juju/+bug/1650451 454 for _, macaroon := range st.macaroons { 455 cookie, err := httpbakery.NewCookie(coremacaroon.MacaroonNamespace, macaroon) 456 if err != nil { 457 return errors.Trace(err) 458 } 459 req.AddCookie(cookie) 460 } 461 } 462 h.Set(httpbakery.BakeryProtocolHeader, fmt.Sprint(bakery.LatestVersion)) 463 return nil 464 } 465 466 // apiEndpoint returns a URL that refers to the given API slash-prefixed 467 // endpoint path and query parameters. 468 func (st *state) apiEndpoint(path, query string) (*url.URL, error) { 469 path, err := apiPath(st.modelTag.Id(), path) 470 if err != nil { 471 return nil, errors.Trace(err) 472 } 473 return &url.URL{ 474 Scheme: st.serverScheme, 475 Host: st.Addr(), 476 Path: path, 477 RawQuery: query, 478 }, nil 479 } 480 481 // ControllerAPIURL returns the URL to use to connect to the controller API. 
482 func ControllerAPIURL(addr string, port int) string { 483 hp := net.JoinHostPort(addr, strconv.Itoa(port)) 484 urlStr, _ := url.QueryUnescape(apiURL(hp, "").String()) 485 return urlStr 486 } 487 488 func apiURL(addr, model string) *url.URL { 489 path, _ := apiPath(model, "/api") 490 return &url.URL{ 491 Scheme: "wss", 492 Host: addr, 493 Path: path, 494 } 495 } 496 497 // ping implements calls the Pinger.ping facade. 498 func (s *state) ping() error { 499 return s.APICall("Pinger", s.pingerFacadeVersion, "", "Ping", nil, nil) 500 } 501 502 // apiPath returns the given API endpoint path relative 503 // to the given model string. 504 func apiPath(model, path string) (string, error) { 505 if !strings.HasPrefix(path, "/") { 506 return "", errors.Errorf("cannot make API path from non-slash-prefixed path %q", path) 507 } 508 if model == "" { 509 return path, nil 510 } 511 return modelRoot + model + path, nil 512 } 513 514 // tagToString returns the value of a tag's String method, or "" if the tag is nil. 515 func tagToString(tag names.Tag) string { 516 if tag == nil { 517 return "" 518 } 519 return tag.String() 520 } 521 522 // dialResult holds a dialed connection, the URL 523 // and TLS configuration used to connect to it. 524 type dialResult struct { 525 conn jsoncodec.JSONConn 526 addr string 527 urlStr string 528 ipAddr string 529 proxier jujuproxy.Proxier 530 tlsConfig *tls.Config 531 } 532 533 // Close implements io.Closer by closing the websocket 534 // connection. It is implemented so that a *dialResult 535 // value can be used as the result of a parallel.Try. 536 func (c *dialResult) Close() error { 537 return c.conn.Close() 538 } 539 540 // dialOpts holds the original dial options 541 // but adds some information for the local dial logic. 542 type dialOpts struct { 543 DialOpts 544 sniHostName string 545 // certPool holds a cert pool containing the CACert 546 // if there is one. 
547 certPool *x509.CertPool 548 } 549 550 // dialAPI establishes a websocket connection to the RPC 551 // API websocket on the API server using Info. If multiple API addresses 552 // are provided in Info they will be tried concurrently - the first successful 553 // connection wins. 554 // 555 // It also returns the TLS configuration that it has derived from the Info. 556 func dialAPI(ctx context.Context, info *Info, opts0 DialOpts) (*dialResult, error) { 557 if len(info.Addrs) == 0 { 558 return nil, errors.New("no API addresses to connect to") 559 } 560 561 addrs := info.Addrs[:] 562 563 if info.Proxier != nil { 564 if err := info.Proxier.Start(); err != nil { 565 return nil, errors.Annotate(err, "starting proxy for api connection") 566 } 567 logger.Debugf("starting proxier for connection") 568 569 switch p := info.Proxier.(type) { 570 case jujuproxy.TunnelProxier: 571 logger.Debugf("tunnel proxy in use at %s on port %s", p.Host(), p.Port()) 572 addrs = []string{ 573 fmt.Sprintf("%s:%s", p.Host(), p.Port()), 574 } 575 default: 576 info.Proxier.Stop() 577 return nil, errors.New("unknown proxier provided") 578 } 579 } 580 581 opts := dialOpts{ 582 DialOpts: opts0, 583 sniHostName: info.SNIHostName, 584 } 585 if info.CACert != "" { 586 certPool, err := CreateCertPool(info.CACert) 587 if err != nil { 588 return nil, errors.Annotate(err, "cert pool creation failed") 589 } 590 opts.certPool = certPool 591 } 592 // Set opts.DialWebsocket and opts.Clock here rather than in open because 593 // some tests call dialAPI directly. 
594 if opts.DialWebsocket == nil { 595 opts.DialWebsocket = gorillaDialWebsocket 596 } 597 if opts.IPAddrResolver == nil { 598 opts.IPAddrResolver = net.DefaultResolver 599 } 600 if opts.Clock == nil { 601 opts.Clock = clock.WallClock 602 } 603 if opts.DNSCache == nil { 604 opts.DNSCache = nopDNSCache{} 605 } 606 path, err := apiPath(info.ModelTag.Id(), "/api") 607 if err != nil { 608 return nil, errors.Trace(err) 609 } 610 611 // Encourage load balancing by shuffling controller addresses. 612 rand.Shuffle(len(addrs), func(i, j int) { addrs[i], addrs[j] = addrs[j], addrs[i] }) 613 614 if opts.VerifyCA != nil { 615 if err := verifyCAMulti(ctx, addrs, &opts); err != nil { 616 return nil, err 617 } 618 } 619 620 if opts.DialTimeout > 0 { 621 ctx1, cancel := utils.ContextWithTimeout(ctx, opts.Clock, opts.DialTimeout) 622 defer cancel() 623 ctx = ctx1 624 } 625 dialInfo, err := dialWebsocketMulti(ctx, addrs, path, opts) 626 if err != nil { 627 return nil, errors.Trace(err) 628 } 629 logger.Infof("connection established to %q", dialInfo.urlStr) 630 dialInfo.proxier = info.Proxier 631 return dialInfo, nil 632 } 633 634 // gorillaDialWebsocket makes a websocket connection using the 635 // gorilla websocket package. The ipAddr parameter holds the 636 // actual IP address that will be contacted - the host in urlStr 637 // is used only for TLS verification when tlsConfig.ServerName 638 // is empty. 639 func gorillaDialWebsocket(ctx context.Context, urlStr string, tlsConfig *tls.Config, ipAddr string) (jsoncodec.JSONConn, error) { 640 url, err := url.Parse(urlStr) 641 if err != nil { 642 return nil, errors.Trace(err) 643 } 644 // TODO(rogpeppe) We'd like to set Deadline here 645 // but that would break lots of tests that rely on 646 // setting a zero timeout. 647 netDialer := net.Dialer{} 648 dialer := &websocket.Dialer{ 649 NetDial: func(netw, addr string) (net.Conn, error) { 650 if addr == url.Host { 651 // Use pre-resolved IP address. 
The address 652 // may be different if a proxy is in use. 653 addr = ipAddr 654 } 655 return netDialer.DialContext(ctx, netw, addr) 656 }, 657 Proxy: proxy.DefaultConfig.GetProxy, 658 HandshakeTimeout: 45 * time.Second, 659 TLSClientConfig: tlsConfig, 660 } 661 // Note: no extra headers. 662 c, resp, err := dialer.Dial(urlStr, nil) 663 if err != nil { 664 if err == websocket.ErrBadHandshake { 665 // If ErrBadHandshake is returned, a non-nil response 666 // is returned so the client can react to auth errors 667 // (for example). 668 defer resp.Body.Close() 669 body, readErr := io.ReadAll(resp.Body) 670 if readErr == nil { 671 err = errors.Errorf( 672 "%s (%s)", 673 strings.TrimSpace(string(body)), 674 http.StatusText(resp.StatusCode), 675 ) 676 } 677 } 678 return nil, errors.Trace(err) 679 } 680 return jsoncodec.NewWebsocketConn(c), nil 681 } 682 683 type resolvedAddress struct { 684 host string 685 ip string 686 port string 687 } 688 689 type addressProvider struct { 690 dnsCache DNSCache 691 ipAddrResolver IPAddrResolver 692 693 // A pool of host addresses to be resolved to one or more IP addresses. 694 addrPool []string 695 696 // A pool of host addresses that got resolved via the DNS cache; these 697 // are kept separate so we can attempt to resolve them without the DNS 698 // cache when we run out of entries in AddrPool. 699 cachedAddrPool []string 700 resolvedAddrs []*resolvedAddress 701 } 702 703 func newAddressProvider(initialAddrs []string, dnsCache DNSCache, ipAddrResolver IPAddrResolver) *addressProvider { 704 return &addressProvider{ 705 dnsCache: dnsCache, 706 ipAddrResolver: ipAddrResolver, 707 addrPool: initialAddrs, 708 } 709 } 710 711 // next returns back either a successfully resolved address or the error that 712 // occurred while attempting to resolve the next address candidate. Calls to 713 // next return io.EOF to indicate that no more addresses are available. 
714 func (ap *addressProvider) next(ctx context.Context) (*resolvedAddress, error) { 715 if len(ap.resolvedAddrs) == 0 { 716 // If we have ran out of addresses to resolve but we have 717 // resolved some via the DNS cache, make another pass for 718 // those with an empty DNS cache to refresh any stale entries. 719 if len(ap.addrPool) == 0 && len(ap.cachedAddrPool) > 0 { 720 ap.addrPool = ap.cachedAddrPool 721 ap.cachedAddrPool = nil 722 ap.dnsCache = emptyDNSCache{ap.dnsCache} 723 } 724 725 // Resolve the next host from the address pool 726 if len(ap.addrPool) != 0 { 727 next := ap.addrPool[0] 728 ap.addrPool = ap.addrPool[1:] 729 730 host, port, err := net.SplitHostPort(next) 731 if err != nil { 732 return nil, errors.Errorf("invalid address %q: %v", next, err) 733 } 734 735 ips := ap.dnsCache.Lookup(host) 736 if len(ips) > 0 { 737 ap.cachedAddrPool = append(ap.cachedAddrPool, next) 738 } else if isNumericHost(host) { 739 ips = []string{host} 740 } else { 741 var err error 742 ips, err = lookupIPAddr(ctx, host, ap.ipAddrResolver) 743 if err != nil { 744 return nil, errors.Errorf("cannot resolve %q: %v", host, err) 745 } 746 ap.dnsCache.Add(host, ips) 747 logger.Debugf("looked up %v -> %v", host, ips) 748 } 749 750 for _, ip := range ips { 751 ap.resolvedAddrs = append(ap.resolvedAddrs, &resolvedAddress{ 752 host: next, 753 ip: ip, 754 port: port, 755 }) 756 } 757 } 758 } 759 760 // Ran out of resolved addresses and cached addresses 761 if len(ap.resolvedAddrs) == 0 { 762 return nil, io.EOF 763 } 764 765 next := ap.resolvedAddrs[0] 766 ap.resolvedAddrs = ap.resolvedAddrs[1:] 767 return next, nil 768 } 769 770 // caRetrieveRes is an adaptor for returning CA certificate lookup results via 771 // calls to parallel.Try. 
type caRetrieveRes struct {
	// host is the original address the endpoint was resolved from.
	host string
	// endpoint is the ip:port the certificate was retrieved from.
	endpoint string
	// caCert is the CA certificate presented by the endpoint.
	caCert *x509.Certificate
}

// Close implements io.Closer; there is nothing to release.
func (caRetrieveRes) Close() error {
	return nil
}

// verifyCAMulti attempts to establish a TLS connection with one of the
// provided addresses, retrieve the CA certificate and validate it using the
// system root CAs. If that is not possible, the certificate verification will
// be delegated to the VerifyCA implementation specified in opts.DialOpts.
//
// If VerifyCA does not return an error, the CA cert is assumed to be trusted
// and will be appended to opts' certificate pool allowing secure websocket
// connections to proceed without certificate verification errors. Otherwise,
// the error reported by VerifyCA is returned back to the caller.
//
// For load-balancing purposes, all addresses are tested concurrently with the
// first retrieved CA cert being used for the verification tests. In addition,
// apart from the initial TLS handshake with the remote server, no other data
// is exchanged with the remote server.
794 func verifyCAMulti(ctx context.Context, addrs []string, opts *dialOpts) error { 795 dOpts := opts.DialOpts 796 if dOpts.DialTimeout > 0 { 797 ctx1, cancel := utils.ContextWithTimeout(ctx, dOpts.Clock, dOpts.DialTimeout) 798 defer cancel() 799 ctx = ctx1 800 } 801 802 try := parallel.NewTry(0, nil) 803 defer try.Kill() 804 805 addrProvider := newAddressProvider(addrs, opts.DNSCache, opts.IPAddrResolver) 806 tryRetrieveCaCertFn := func(ctx context.Context, addr *resolvedAddress) func(<-chan struct{}) (io.Closer, error) { 807 ipStr := net.JoinHostPort(addr.ip, addr.port) 808 return func(<-chan struct{}) (io.Closer, error) { 809 caCert, err := retrieveCACert(ctx, ipStr) 810 if err != nil { 811 return nil, err 812 } 813 814 return caRetrieveRes{ 815 host: addr.host, 816 endpoint: ipStr, 817 caCert: caCert, 818 }, nil 819 } 820 } 821 822 for { 823 resolvedAddr, err := addrProvider.next(ctx) 824 if err == io.EOF { 825 break 826 } else if err != nil { 827 recordTryError(try, err) 828 continue 829 } 830 831 err = try.Start(tryRetrieveCaCertFn(ctx, resolvedAddr)) 832 if err == parallel.ErrStopped { 833 break 834 } else if err != nil { 835 continue 836 } 837 838 select { 839 case <-opts.Clock.After(dOpts.DialAddressInterval): 840 case <-try.Dead(): 841 } 842 } 843 844 try.Close() 845 846 // If we are unable to fetch the CA either because it is not presented 847 // by the remote server OR due to an unsuccessful connection attempt 848 // we should skip the verification path and dial the server as if no 849 // VerifyCA implementation was provided. 850 result, err := try.Result() 851 if err != nil || result == nil { 852 logger.Debugf("unable to retrieve CA cert from remote host; skipping CA verification") 853 return nil 854 } 855 856 // Try to verify CA cert using the system roots. If the verification 857 // succeeds then we are done; tls connections will work out of the box. 
858 res := result.(caRetrieveRes) 859 if _, err = res.caCert.Verify(x509.VerifyOptions{}); err == nil { 860 logger.Debugf("remote CA certificate trusted by system roots") 861 return nil 862 } 863 864 // Invoke the CA verifier; if the CA should be trusted, append it to 865 // the dialOpts certPool and proceed with the actual connection attempt. 866 err = opts.VerifyCA(res.host, res.endpoint, res.caCert) 867 if err == nil { 868 if opts.certPool == nil { 869 opts.certPool = x509.NewCertPool() 870 } 871 opts.certPool.AddCert(res.caCert) 872 } 873 874 return err 875 } 876 877 // retrieveCACert establishes an insecure TLS connection to addr and attempts 878 // to retrieve the CA cert presented by the server. If no CA cert is presented, 879 // retrieveCACert will returns nil, nil. 880 func retrieveCACert(ctx context.Context, addr string) (*x509.Certificate, error) { 881 netConn, err := new(net.Dialer).DialContext(ctx, "tcp", addr) 882 if err != nil { 883 return nil, err 884 } 885 886 conn := tls.Client(netConn, &tls.Config{InsecureSkipVerify: true}) 887 if err = conn.Handshake(); err != nil { 888 _ = netConn.Close() 889 return nil, err 890 } 891 defer func() { 892 _ = conn.Close() 893 _ = netConn.Close() 894 }() 895 896 for _, cert := range conn.ConnectionState().PeerCertificates { 897 if cert.IsCA { 898 return cert, nil 899 } 900 } 901 902 return nil, errors.New("no CA certificate presented by remote server") 903 } 904 905 // dialWebsocketMulti dials a websocket with one of the provided addresses, the 906 // specified URL path, TLS configuration, and dial options. Each of the 907 // specified addresses will be attempted concurrently, and the first 908 // successful connection will be returned. 909 func dialWebsocketMulti(ctx context.Context, addrs []string, path string, opts dialOpts) (*dialResult, error) { 910 // Prioritise non-dial errors over the normal "connection refused". 
911 isDialError := func(err error) bool { 912 netErr, ok := errors.Cause(err).(*net.OpError) 913 if !ok { 914 return false 915 } 916 return netErr.Op == "dial" 917 } 918 combine := func(initial, other error) error { 919 if initial == nil || isDialError(initial) { 920 return other 921 } 922 if isDialError(other) { 923 return initial 924 } 925 return other 926 } 927 // Dial all addresses at reasonable intervals. 928 try := parallel.NewTry(0, combine) 929 defer try.Kill() 930 // Make a context that's cancelled when the try 931 // completes so that (for example) a slow DNS 932 // query will be cancelled if a previous try succeeds. 933 ctx, cancel := context.WithCancel(ctx) 934 go func() { 935 <-try.Dead() 936 cancel() 937 }() 938 tried := make(map[string]bool) 939 addrProvider := newAddressProvider(addrs, opts.DNSCache, opts.IPAddrResolver) 940 for { 941 resolvedAddr, err := addrProvider.next(ctx) 942 if err == io.EOF { 943 break 944 } else if err != nil { 945 recordTryError(try, err) 946 continue 947 } 948 949 ipStr := net.JoinHostPort(resolvedAddr.ip, resolvedAddr.port) 950 if tried[ipStr] { 951 continue 952 } 953 tried[ipStr] = true 954 err = startDialWebsocket(ctx, try, ipStr, resolvedAddr.host, path, opts) 955 if err == parallel.ErrStopped { 956 break 957 } 958 if err != nil { 959 return nil, errors.Trace(err) 960 } 961 select { 962 case <-opts.Clock.After(opts.DialAddressInterval): 963 case <-try.Dead(): 964 } 965 } 966 try.Close() 967 result, err := try.Result() 968 if err != nil { 969 return nil, errors.Trace(err) 970 } 971 return result.(*dialResult), nil 972 } 973 974 func lookupIPAddr(ctx context.Context, host string, resolver IPAddrResolver) ([]string, error) { 975 addrs, err := resolver.LookupIPAddr(ctx, host) 976 if err != nil { 977 return nil, errors.Trace(err) 978 } 979 ips := make([]string, 0, len(addrs)) 980 for _, addr := range addrs { 981 if addr.Zone != "" { 982 // Ignore IPv6 zone. Hopefully this shouldn't 983 // cause any problems in practice. 
984 logger.Infof("ignoring IP address with zone %q", addr) 985 continue 986 } 987 ips = append(ips, addr.IP.String()) 988 } 989 return ips, nil 990 } 991 992 // recordTryError starts a try that just returns the given error. 993 // This is so that we can use the usual Try error combination 994 // logic even for errors that happen before we start a try. 995 func recordTryError(try *parallel.Try, err error) { 996 logger.Infof("%v", err) 997 _ = try.Start(func(_ <-chan struct{}) (io.Closer, error) { 998 return nil, errors.Trace(err) 999 }) 1000 } 1001 1002 var oneAttempt = retry.LimitCount(1, retry.Regular{ 1003 Min: 1, 1004 }) 1005 1006 // startDialWebsocket starts websocket connection to a single address 1007 // on the given try instance. 1008 func startDialWebsocket(ctx context.Context, try *parallel.Try, ipAddr, addr, path string, opts dialOpts) error { 1009 var openAttempt retry.Strategy 1010 if opts.RetryDelay > 0 { 1011 openAttempt = retry.Regular{ 1012 Total: opts.Timeout, 1013 Delay: opts.RetryDelay, 1014 Min: int(opts.Timeout / opts.RetryDelay), 1015 } 1016 } else { 1017 // Zero retry delay implies exactly one try. 1018 openAttempt = oneAttempt 1019 } 1020 d := dialer{ 1021 ctx: ctx, 1022 openAttempt: openAttempt, 1023 serverName: opts.sniHostName, 1024 ipAddr: ipAddr, 1025 urlStr: "wss://" + addr + path, 1026 addr: addr, 1027 opts: opts, 1028 } 1029 return try.Start(d.dial) 1030 } 1031 1032 type dialer struct { 1033 ctx context.Context 1034 openAttempt retry.Strategy 1035 1036 // serverName holds the SNI name to use 1037 // when connecting with a public certificate. 1038 serverName string 1039 1040 // addr holds the host:port that is being dialed. 1041 addr string 1042 1043 // addr holds the ipaddr:port (one of the addresses 1044 // that addr resolves to) that is being dialed. 1045 ipAddr string 1046 1047 // urlStr holds the URL that is being dialed. 1048 urlStr string 1049 1050 // opts holds the dial options. 
1051 opts dialOpts 1052 } 1053 1054 // dial implements the function value expected by Try.Start 1055 // by dialing the websocket as specified in d and retrying 1056 // when appropriate. 1057 func (d dialer) dial(_ <-chan struct{}) (io.Closer, error) { 1058 a := retry.StartWithCancel(d.openAttempt, d.opts.Clock, d.ctx.Done()) 1059 var lastErr error = nil 1060 for a.Next() { 1061 conn, tlsConfig, err := d.dial1() 1062 if err == nil { 1063 return &dialResult{ 1064 conn: conn, 1065 addr: d.addr, 1066 ipAddr: d.ipAddr, 1067 urlStr: d.urlStr, 1068 tlsConfig: tlsConfig, 1069 }, nil 1070 } 1071 if isX509Error(err) || !a.More() { 1072 // certificate errors don't improve with retries. 1073 return nil, errors.Annotatef(err, "unable to connect to API") 1074 } 1075 lastErr = err 1076 } 1077 if lastErr == nil { 1078 logger.Debugf("no error, but not connected, probably cancelled before we started") 1079 return nil, parallel.ErrStopped 1080 } 1081 return nil, errors.Trace(lastErr) 1082 } 1083 1084 // dial1 makes a single dial attempt. 1085 func (d dialer) dial1() (jsoncodec.JSONConn, *tls.Config, error) { 1086 tlsConfig := NewTLSConfig(d.opts.certPool) 1087 tlsConfig.InsecureSkipVerify = d.opts.InsecureSkipVerify 1088 if d.opts.certPool == nil { 1089 tlsConfig.ServerName = d.serverName 1090 } 1091 logger.Tracef("dialing: %q %v", d.urlStr, d.ipAddr) 1092 conn, err := d.opts.DialWebsocket(d.ctx, d.urlStr, tlsConfig, d.ipAddr) 1093 if err == nil { 1094 logger.Debugf("successfully dialed %q", d.urlStr) 1095 return conn, tlsConfig, nil 1096 } 1097 if !isX509Error(err) { 1098 return nil, nil, errors.Trace(err) 1099 } 1100 if tlsConfig.RootCAs == nil || d.serverName == "" { 1101 // There's no private certificate or we don't have a 1102 // public hostname. In the former case, we've already 1103 // tried public certificates; in the latter, public cert 1104 // validation won't help, because you generally can't 1105 // obtain a public cert for a numeric IP address. 
In 1106 // both those cases, we won't succeed when trying again 1107 // because a cert error isn't temporary, so return 1108 // immediately. 1109 // 1110 // Note that the error returned from 1111 // websocket.DialConfig always includes the location in 1112 // the message. 1113 return nil, nil, errors.Trace(err) 1114 } 1115 // It's possible we're inappropriately using the private 1116 // CA certificate, so retry immediately with the public one. 1117 tlsConfig.RootCAs = nil 1118 tlsConfig.ServerName = d.serverName 1119 conn, rootCAErr := d.opts.DialWebsocket(d.ctx, d.urlStr, tlsConfig, d.ipAddr) 1120 if rootCAErr != nil { 1121 logger.Debugf("failed to dial websocket using fallback public CA: %v", rootCAErr) 1122 // We return the original error as it's usually more meaningful. 1123 return nil, nil, errors.Trace(err) 1124 } 1125 return conn, tlsConfig, nil 1126 } 1127 1128 // NewTLSConfig returns a new *tls.Config suitable for connecting to a Juju 1129 // API server. If certPool is non-nil, we use it as the config's RootCAs, 1130 // and the server name is set to "juju-apiserver". 1131 func NewTLSConfig(certPool *x509.CertPool) *tls.Config { 1132 tlsConfig := jujuhttp.SecureTLSConfig() 1133 if certPool != nil { 1134 // We want to be specific here (rather than just using "anything"). 1135 // See commit 7fc118f015d8480dfad7831788e4b8c0432205e8 (PR 899). 1136 tlsConfig.RootCAs = certPool 1137 tlsConfig.ServerName = "juju-apiserver" 1138 } 1139 return tlsConfig 1140 } 1141 1142 // isNumericHost reports whether the given host name is 1143 // a numeric IP address. 1144 func isNumericHost(host string) bool { 1145 return net.ParseIP(host) != nil 1146 } 1147 1148 // isX509Error reports whether the given websocket error 1149 // results from an X509 problem. 
1150 func isX509Error(err error) bool { 1151 switch errType := errors.Cause(err).(type) { 1152 case *websocket.CloseError: 1153 return errType.Code == websocket.CloseTLSHandshake 1154 case x509.CertificateInvalidError, 1155 x509.HostnameError, 1156 x509.InsecureAlgorithmError, 1157 x509.UnhandledCriticalExtension, 1158 x509.UnknownAuthorityError, 1159 x509.ConstraintViolationError, 1160 x509.SystemRootsError: 1161 return true 1162 default: 1163 return false 1164 } 1165 } 1166 1167 // APICall places a call to the remote machine. 1168 // 1169 // This fills out the rpc.Request on the given facade, version for a given 1170 // object id, and the specific RPC method. It marshalls the Arguments, and will 1171 // unmarshall the result into the response object that is supplied. 1172 func (s *state) APICall(facade string, vers int, id, method string, args, response interface{}) error { 1173 return s.client.Call(rpc.Request{ 1174 Type: facade, 1175 Version: vers, 1176 Id: id, 1177 Action: method, 1178 }, args, response) 1179 } 1180 1181 func (s *state) Close() error { 1182 err := s.client.Close() 1183 select { 1184 case <-s.closed: 1185 default: 1186 close(s.closed) 1187 } 1188 <-s.broken 1189 if s.proxier != nil { 1190 s.proxier.Stop() 1191 } 1192 return err 1193 } 1194 1195 // BakeryClient implements api.Connection. 1196 func (s *state) BakeryClient() base.MacaroonDischarger { 1197 return s.bakeryClient 1198 } 1199 1200 // Broken implements api.Connection. 1201 func (s *state) Broken() <-chan struct{} { 1202 return s.broken 1203 } 1204 1205 // IsBroken implements api.Connection. 1206 func (s *state) IsBroken() bool { 1207 select { 1208 case <-s.broken: 1209 return true 1210 default: 1211 } 1212 if err := s.ping(); err != nil { 1213 logger.Debugf("connection ping failed: %v", err) 1214 return true 1215 } 1216 return false 1217 } 1218 1219 // Addr returns the address used to connect to the API server. 
1220 func (s *state) Addr() string { 1221 return s.addr 1222 } 1223 1224 // IPAddr returns the resolved IP address that was used to 1225 // connect to the API server. 1226 func (s *state) IPAddr() string { 1227 return s.ipAddr 1228 } 1229 1230 // IsProxied indicates if this connection was proxied 1231 func (s *state) IsProxied() bool { 1232 return s.proxier != nil 1233 } 1234 1235 // Proxy returns the proxy being used with this connection if one is being used. 1236 func (s *state) Proxy() jujuproxy.Proxier { 1237 return s.proxier 1238 } 1239 1240 // ModelTag implements base.APICaller.ModelTag. 1241 func (s *state) ModelTag() (names.ModelTag, bool) { 1242 return s.modelTag, s.modelTag.Id() != "" 1243 } 1244 1245 // ControllerTag implements base.APICaller.ControllerTag. 1246 func (s *state) ControllerTag() names.ControllerTag { 1247 return s.controllerTag 1248 } 1249 1250 // APIHostPorts returns addresses that may be used to connect 1251 // to the API server, including the address used to connect. 1252 // 1253 // The addresses are scoped (public, cloud-internal, etc.), so 1254 // the client may choose which addresses to attempt. For the 1255 // Juju CLI, all addresses must be attempted, as the CLI may 1256 // be invoked both within and outside the model (think 1257 // private clouds). 1258 func (s *state) APIHostPorts() []network.MachineHostPorts { 1259 // NOTE: We're making a copy of s.hostPorts before returning it, 1260 // for safety. 1261 hostPorts := make([]network.MachineHostPorts, len(s.hostPorts)) 1262 for i, servers := range s.hostPorts { 1263 hostPorts[i] = append(network.MachineHostPorts{}, servers...) 1264 } 1265 return hostPorts 1266 } 1267 1268 // PublicDNSName returns the host name for which an officially 1269 // signed certificate will be used for TLS connection to the server. 1270 // If empty, the private Juju CA certificate must be used to verify 1271 // the connection. 
1272 func (s *state) PublicDNSName() string { 1273 return s.publicDNSName 1274 } 1275 1276 // BestFacadeVersion compares the versions of facades that we know about, and 1277 // the versions available from the server, and reports back what version is the 1278 // 'best available' to use. 1279 // TODO(jam) this is the eventual implementation of what version of a given 1280 // Facade we will want to use. It needs to line up the versions that the server 1281 // reports to us, with the versions that our client knows how to use. 1282 func (s *state) BestFacadeVersion(facade string) int { 1283 return facades.BestVersion(facadeVersions[facade], s.facadeVersions[facade]) 1284 } 1285 1286 // serverRoot returns the cached API server address and port used 1287 // to login, prefixed with "<URI scheme>://" (usually https). 1288 func (s *state) serverRoot() string { 1289 return s.serverScheme + "://" + s.serverRootAddress 1290 } 1291 1292 func (s *state) isLoggedIn() bool { 1293 return atomic.LoadInt32(&s.loggedIn) == 1 1294 } 1295 1296 func (s *state) setLoggedIn() { 1297 atomic.StoreInt32(&s.loggedIn, 1) 1298 } 1299 1300 // emptyDNSCache implements DNSCache by 1301 // never returning any entries but writing any 1302 // added entries to the embedded DNSCache object. 1303 type emptyDNSCache struct { 1304 DNSCache 1305 } 1306 1307 func (emptyDNSCache) Lookup(host string) []string { 1308 return nil 1309 } 1310 1311 type nopDNSCache struct{} 1312 1313 func (nopDNSCache) Lookup(host string) []string { 1314 return nil 1315 } 1316 1317 func (nopDNSCache) Add(host string, ips []string) { 1318 }