github.com/mhilton/juju-juju@v0.0.0-20150901100907-a94dd2c73455/api/apiclient.go (about)

     1  // Copyright 2012-2015 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package api
     5  
     6  import (
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"io"
    10  	"net/http"
    11  	"strings"
    12  	"time"
    13  
    14  	"github.com/juju/errors"
    15  	"github.com/juju/loggo"
    16  	"github.com/juju/names"
    17  	"github.com/juju/utils"
    18  	"github.com/juju/utils/parallel"
    19  	"golang.org/x/net/websocket"
    20  
    21  	"github.com/juju/juju/apiserver/params"
    22  	"github.com/juju/juju/network"
    23  	"github.com/juju/juju/rpc"
    24  	"github.com/juju/juju/rpc/jsoncodec"
    25  	"github.com/juju/juju/version"
    26  )
    27  
    28  var logger = loggo.GetLogger("juju.api")
    29  
    30  // PingPeriod defines how often the internal connection health check
    31  // will run. It's a variable so it can be changed in tests.
    32  var PingPeriod = 1 * time.Minute
    33  
    34  type State struct {
    35  	client *rpc.Conn
    36  	conn   *websocket.Conn
    37  
    38  	// addr is the address used to connect to the API server.
    39  	addr string
    40  
    41  	// environTag holds the environment tag once we're connected
    42  	environTag string
    43  
    44  	// serverTag holds the server tag once we're connected.
    45  	// This is only set with newer apiservers where they are using
    46  	// the v1 login mechansim.
    47  	serverTag string
    48  
    49  	// serverVersion holds the version of the API server that we are
    50  	// connected to.  It is possible that this version is 0 if the
    51  	// server does not report this during login.
    52  	serverVersion version.Number
    53  
    54  	// hostPorts is the API server addresses returned from Login,
    55  	// which the client may cache and use for failover.
    56  	hostPorts [][]network.HostPort
    57  
    58  	// facadeVersions holds the versions of all facades as reported by
    59  	// Login
    60  	facadeVersions map[string][]int
    61  
    62  	// authTag holds the authenticated entity's tag after login.
    63  	authTag names.Tag
    64  
    65  	// broken is a channel that gets closed when the connection is
    66  	// broken.
    67  	broken chan struct{}
    68  
    69  	// closed is a channel that gets closed when State.Close is called.
    70  	closed chan struct{}
    71  
    72  	// tag and password hold the cached login credentials.
    73  	tag      string
    74  	password string
    75  
    76  	// serverRootAddress holds the cached API server address and port used
    77  	// to login.
    78  	serverRootAddress string
    79  
    80  	// serverScheme is the URI scheme of the API Server
    81  	serverScheme string
    82  
    83  	// certPool holds the cert pool that is used to authenticate the tls
    84  	// connections to the API.
    85  	certPool *x509.CertPool
    86  }
    87  
    88  // Open establishes a connection to the API server using the Info
    89  // given, returning a State instance which can be used to make API
    90  // requests.
    91  //
    92  // See Connect for details of the connection mechanics.
    93  func Open(info *Info, opts DialOpts) (Connection, error) {
    94  	return open(info, opts, (*State).Login)
    95  }
    96  
    97  // This unexported open method is used both directly above in the Open
    98  // function, and also the OpenWithVersion function below to explicitly cause
    99  // the API server to think that the client is older than it really is.
   100  func open(info *Info, opts DialOpts, loginFunc func(st *State, tag, pwd, nonce string) error) (Connection, error) {
   101  	conn, err := Connect(info, "", nil, opts)
   102  	if err != nil {
   103  		return nil, errors.Trace(err)
   104  	}
   105  
   106  	client := rpc.NewConn(jsoncodec.NewWebsocket(conn), nil)
   107  	client.Start()
   108  	st := &State{
   109  		client:            client,
   110  		conn:              conn,
   111  		addr:              conn.Config().Location.Host,
   112  		serverScheme:      "https",
   113  		serverRootAddress: conn.Config().Location.Host,
   114  		// why are the contents of the tag (username and password) written into the
   115  		// state structure BEFORE login ?!?
   116  		tag:      toString(info.Tag),
   117  		password: info.Password,
   118  		certPool: conn.Config().TlsConfig.RootCAs,
   119  	}
   120  	if info.Tag != nil || info.Password != "" {
   121  		if err := loginFunc(st, info.Tag.String(), info.Password, info.Nonce); err != nil {
   122  			conn.Close()
   123  			return nil, err
   124  		}
   125  	}
   126  	st.broken = make(chan struct{})
   127  	st.closed = make(chan struct{})
   128  	go st.heartbeatMonitor()
   129  	return st, nil
   130  }
   131  
   132  // OpenWithVersion uses an explicit version of the Admin facade to call Login
   133  // on. This allows the caller to pretend to be an older client, and is used
   134  // only in testing.
   135  func OpenWithVersion(info *Info, opts DialOpts, loginVersion int) (Connection, error) {
   136  	var loginFunc func(st *State, tag, pwd, nonce string) error
   137  	switch loginVersion {
   138  	case 0:
   139  		loginFunc = (*State).loginV0
   140  	case 1:
   141  		loginFunc = (*State).loginV1
   142  	case 2:
   143  		loginFunc = (*State).loginV2
   144  	default:
   145  		return nil, errors.NotSupportedf("loginVersion %d", loginVersion)
   146  	}
   147  	return open(info, opts, loginFunc)
   148  }
   149  
   150  // Connect establishes a websocket connection to the API server using
   151  // the Info, API path tail and (optional) request headers provided. If
   152  // multiple API addresses are provided in Info they will be tried
   153  // concurrently - the first successful connection wins.
   154  //
   155  // The path tail may be blank, in which case the default value will be
   156  // used. Otherwise, it must start with a "/".
   157  func Connect(info *Info, pathTail string, header http.Header, opts DialOpts) (*websocket.Conn, error) {
   158  	if len(info.Addrs) == 0 {
   159  		return nil, errors.New("no API addresses to connect to")
   160  	}
   161  	if pathTail != "" && !strings.HasPrefix(pathTail, "/") {
   162  		return nil, errors.New(`path tail must start with "/"`)
   163  	}
   164  
   165  	pool, err := CreateCertPool(info.CACert)
   166  	if err != nil {
   167  		return nil, errors.Annotate(err, "cert pool creation failed")
   168  	}
   169  
   170  	path := makeAPIPath(info.EnvironTag.Id(), pathTail)
   171  
   172  	// Dial all addresses at reasonable intervals.
   173  	try := parallel.NewTry(0, nil)
   174  	defer try.Kill()
   175  	for _, addr := range info.Addrs {
   176  		err := dialWebsocket(addr, path, header, opts, pool, try)
   177  		if err == parallel.ErrStopped {
   178  			break
   179  		}
   180  		if err != nil {
   181  			return nil, errors.Trace(err)
   182  		}
   183  		select {
   184  		case <-time.After(opts.DialAddressInterval):
   185  		case <-try.Dead():
   186  		}
   187  	}
   188  	try.Close()
   189  	result, err := try.Result()
   190  	if err != nil {
   191  		return nil, errors.Trace(err)
   192  	}
   193  	conn := result.(*websocket.Conn)
   194  	logger.Infof("connection established to %q", conn.RemoteAddr())
   195  	return conn, nil
   196  }
   197  
   198  // makeAPIPath builds the path to connect to based on the tail given
   199  // and whether the environment UUID is set.
   200  func makeAPIPath(envUUID, tail string) string {
   201  	if envUUID == "" {
   202  		if tail == "" {
   203  			tail = "/"
   204  		}
   205  		return tail
   206  	}
   207  	if tail == "" {
   208  		tail = "/api"
   209  	}
   210  	return "/environment/" + envUUID + tail
   211  }
   212  
   213  // toString returns the value of a tag's String method, or "" if the tag is nil.
   214  func toString(tag names.Tag) string {
   215  	if tag == nil {
   216  		return ""
   217  	}
   218  	return tag.String()
   219  }
   220  
   221  func dialWebsocket(addr, path string, header http.Header, opts DialOpts, rootCAs *x509.CertPool, try *parallel.Try) error {
   222  	cfg, err := setUpWebsocket(addr, path, header, rootCAs)
   223  	if err != nil {
   224  		return err
   225  	}
   226  	return try.Start(newWebsocketDialer(cfg, opts))
   227  }
   228  
   229  func setUpWebsocket(addr, path string, header http.Header, rootCAs *x509.CertPool) (*websocket.Config, error) {
   230  	// origin is required by the WebSocket API, used for "origin policy"
   231  	// in websockets. We pass localhost to satisfy the API; it is
   232  	// inconsequential to us.
   233  	const origin = "http://localhost/"
   234  	cfg, err := websocket.NewConfig("wss://"+addr+path, origin)
   235  	if err != nil {
   236  		return nil, errors.Trace(err)
   237  	}
   238  	cfg.TlsConfig = &tls.Config{
   239  		RootCAs:    rootCAs,
   240  		ServerName: "juju-apiserver",
   241  	}
   242  	cfg.Header = header
   243  	return cfg, nil
   244  }
   245  
   246  // newWebsocketDialer returns a function that
   247  // can be passed to utils/parallel.Try.Start.
   248  var newWebsocketDialer = createWebsocketDialer
   249  
   250  func createWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan struct{}) (io.Closer, error) {
   251  	openAttempt := utils.AttemptStrategy{
   252  		Total: opts.Timeout,
   253  		Delay: opts.RetryDelay,
   254  	}
   255  	return func(stop <-chan struct{}) (io.Closer, error) {
   256  		for a := openAttempt.Start(); a.Next(); {
   257  			select {
   258  			case <-stop:
   259  				return nil, parallel.ErrStopped
   260  			default:
   261  			}
   262  			logger.Infof("dialing %q", cfg.Location)
   263  			conn, err := websocket.DialConfig(cfg)
   264  			if err == nil {
   265  				return conn, nil
   266  			}
   267  			if a.HasNext() {
   268  				logger.Debugf("error dialing %q, will retry: %v", cfg.Location, err)
   269  			} else {
   270  				logger.Infof("error dialing %q: %v", cfg.Location, err)
   271  				return nil, errors.Errorf("unable to connect to %q", cfg.Location)
   272  			}
   273  		}
   274  		panic("unreachable")
   275  	}
   276  }
   277  
   278  func (s *State) heartbeatMonitor() {
   279  	for {
   280  		if err := s.Ping(); err != nil {
   281  			close(s.broken)
   282  			return
   283  		}
   284  		select {
   285  		case <-time.After(PingPeriod):
   286  		case <-s.closed:
   287  		}
   288  	}
   289  }
   290  
   291  func (s *State) Ping() error {
   292  	return s.APICall("Pinger", s.BestFacadeVersion("Pinger"), "", "Ping", nil, nil)
   293  }
   294  
   295  // APICall places a call to the remote machine.
   296  //
   297  // This fills out the rpc.Request on the given facade, version for a given
   298  // object id, and the specific RPC method. It marshalls the Arguments, and will
   299  // unmarshall the result into the response object that is supplied.
   300  func (s *State) APICall(facade string, version int, id, method string, args, response interface{}) error {
   301  	err := s.client.Call(rpc.Request{
   302  		Type:    facade,
   303  		Version: version,
   304  		Id:      id,
   305  		Action:  method,
   306  	}, args, response)
   307  	return params.ClientError(err)
   308  }
   309  
   310  func (s *State) Close() error {
   311  	err := s.client.Close()
   312  	select {
   313  	case <-s.closed:
   314  	default:
   315  		close(s.closed)
   316  	}
   317  	<-s.broken
   318  	return err
   319  }
   320  
   321  // Broken returns a channel that's closed when the connection is broken.
   322  func (s *State) Broken() <-chan struct{} {
   323  	return s.broken
   324  }
   325  
   326  // RPCClient returns the RPC client for the state, so that testing
   327  // functions can tickle parts of the API that the conventional entry
   328  // points don't reach. This is exported for testing purposes only.
   329  func (s *State) RPCClient() *rpc.Conn {
   330  	return s.client
   331  }
   332  
   333  // Addr returns the address used to connect to the API server.
   334  func (s *State) Addr() string {
   335  	return s.addr
   336  }
   337  
   338  // EnvironTag returns the tag of the environment we are connected to.
   339  func (s *State) EnvironTag() (names.EnvironTag, error) {
   340  	return names.ParseEnvironTag(s.environTag)
   341  }
   342  
   343  // ServerTag returns the tag of the server we are connected to.
   344  func (s *State) ServerTag() (names.EnvironTag, error) {
   345  	return names.ParseEnvironTag(s.serverTag)
   346  }
   347  
   348  // APIHostPorts returns addresses that may be used to connect
   349  // to the API server, including the address used to connect.
   350  //
   351  // The addresses are scoped (public, cloud-internal, etc.), so
   352  // the client may choose which addresses to attempt. For the
   353  // Juju CLI, all addresses must be attempted, as the CLI may
   354  // be invoked both within and outside the environment (think
   355  // private clouds).
   356  func (s *State) APIHostPorts() [][]network.HostPort {
   357  	// NOTE: We're making a copy of s.hostPorts before returning it,
   358  	// for safety.
   359  	hostPorts := make([][]network.HostPort, len(s.hostPorts))
   360  	for i, server := range s.hostPorts {
   361  		hostPorts[i] = append([]network.HostPort{}, server...)
   362  	}
   363  	return hostPorts
   364  }
   365  
   366  // AllFacadeVersions returns what versions we know about for all facades
   367  func (s *State) AllFacadeVersions() map[string][]int {
   368  	facades := make(map[string][]int, len(s.facadeVersions))
   369  	for name, versions := range s.facadeVersions {
   370  		facades[name] = append([]int{}, versions...)
   371  	}
   372  	return facades
   373  }
   374  
   375  // BestFacadeVersion compares the versions of facades that we know about, and
   376  // the versions available from the server, and reports back what version is the
   377  // 'best available' to use.
   378  // TODO(jam) this is the eventual implementation of what version of a given
   379  // Facade we will want to use. It needs to line up the versions that the server
   380  // reports to us, with the versions that our client knows how to use.
   381  func (s *State) BestFacadeVersion(facade string) int {
   382  	return bestVersion(facadeVersions[facade], s.facadeVersions[facade])
   383  }
   384  
   385  // serverRoot returns the cached API server address and port used
   386  // to login, prefixed with "<URI scheme>://" (usually https).
   387  func (s *State) serverRoot() string {
   388  	return s.serverScheme + "://" + s.serverRootAddress
   389  }