github.com/mattyw/juju@v0.0.0-20140610034352-732aecd63861/state/api/apiclient.go (about)

     1  // Copyright 2012, 2013 Canonical Ltd.
     2  // Licensed under the AGPLv3, see LICENCE file for details.
     3  
     4  package api
     5  
     6  import (
     7  	"crypto/tls"
     8  	"crypto/x509"
     9  	"fmt"
    10  	"io"
    11  	"strings"
    12  	"time"
    13  
    14  	"code.google.com/p/go.net/websocket"
    15  	"github.com/juju/loggo"
    16  	"github.com/juju/names"
    17  	"github.com/juju/utils"
    18  	"github.com/juju/utils/parallel"
    19  
    20  	"github.com/juju/juju/cert"
    21  	"github.com/juju/juju/instance"
    22  	"github.com/juju/juju/rpc"
    23  	"github.com/juju/juju/rpc/jsoncodec"
    24  	"github.com/juju/juju/state/api/params"
    25  )
    26  
    27  var logger = loggo.GetLogger("juju.state.api")
    28  
    29  // PingPeriod defines how often the internal connection health check
    30  // will run. It's a variable so it can be changed in tests.
    31  var PingPeriod = 1 * time.Minute
    32  
// State is a client connection to the API server, as returned by Open.
// API calls are made on it through Call.
type State struct {
	// client is the RPC connection that Call sends requests over;
	// conn is the websocket connection that client was built on.
	client *rpc.Conn
	conn   *websocket.Conn

	// addr is the address used to connect to the API server.
	// Open sets it to the host of the websocket connection's location.
	addr string

	// environTag holds the environment tag once we're connected.
	// NOTE(review): not assigned anywhere in this file; presumably
	// filled in by Login — confirm.
	environTag string

	// hostPorts is the API server addresses returned from Login,
	// which the client may cache and use for failover.
	// NOTE(review): not assigned in this file — presumably set by Login.
	hostPorts [][]instance.HostPort

	// authTag holds the authenticated entity's tag after login.
	authTag string

	// broken is a channel that gets closed when the connection is
	// broken (heartbeatMonitor closes it when a Ping fails).
	broken chan struct{}

	// tag and password hold the cached login credentials.
	tag      string
	password string

	// serverRoot holds the cached API server address and port we used
	// to login, with a https:// prefix.
	serverRoot string

	// certPool holds the cert pool that is used to authenticate the tls
	// connections to the API.
	certPool *x509.CertPool
}
    66  
// Info encapsulates information about a server holding juju state and
// can be used to make a connection to it.
type Info struct {
	// Addrs holds the addresses of the state servers. Open refuses an
	// empty list, and if any entry starts with "localhost:" it dials
	// only the first such entry.
	Addrs []string

	// CACert holds the CA certificate that will be used
	// to validate the state server's certificate, in PEM format.
	CACert string

	// Tag holds the name of the entity that is connecting.
	// If this and the password are empty, no login attempt will be made
	// (this is to allow tests to access the API to check that operations
	// fail when not logged in).
	Tag string

	// Password holds the password for the administrator or connecting entity.
	Password string

	// Nonce holds the nonce used when provisioning the machine. Used
	// only by the machine agent. Open passes it to Login together with
	// Tag and Password.
	Nonce string `yaml:",omitempty"`

	// EnvironTag holds the tag of the environment we are trying to
	// connect to. When non-empty it must parse as an environment tag,
	// and Open connects to that environment's "/environment/<uuid>/api"
	// endpoint; when empty, Open connects to the "/" endpoint.
	EnvironTag string
}
    94  
// DialOpts holds configuration parameters that control the
// dialing behavior when connecting to a state server.
type DialOpts struct {
	// DialAddressInterval is the amount of time to wait
	// before starting to dial another address.
	DialAddressInterval time.Duration

	// Timeout is the amount of time to wait contacting
	// a state server. It is used as the Total of the retry
	// strategy applied separately to each address dialed.
	Timeout time.Duration

	// RetryDelay is the amount of time to wait between
	// unsuccessful connection attempts (the Delay of the
	// per-address retry strategy).
	RetryDelay time.Duration
}
   110  
   111  // DefaultDialOpts returns a DialOpts representing the default
   112  // parameters for contacting a state server.
   113  func DefaultDialOpts() DialOpts {
   114  	return DialOpts{
   115  		DialAddressInterval: 50 * time.Millisecond,
   116  		Timeout:             10 * time.Minute,
   117  		RetryDelay:          2 * time.Second,
   118  	}
   119  }
   120  
   121  func Open(info *Info, opts DialOpts) (*State, error) {
   122  	if len(info.Addrs) == 0 {
   123  		return nil, fmt.Errorf("no API addresses to connect to")
   124  	}
   125  	pool := x509.NewCertPool()
   126  	xcert, err := cert.ParseCert(info.CACert)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  	pool.AddCert(xcert)
   131  
   132  	environUUID := ""
   133  	if info.EnvironTag != "" {
   134  		_, envUUID, err := names.ParseTag(info.EnvironTag, names.EnvironTagKind)
   135  		if err != nil {
   136  			return nil, err
   137  		}
   138  		environUUID = envUUID
   139  	}
   140  	// Dial all addresses at reasonable intervals.
   141  	try := parallel.NewTry(0, nil)
   142  	defer try.Kill()
   143  	var addrs []string
   144  	for _, addr := range info.Addrs {
   145  		if strings.HasPrefix(addr, "localhost:") {
   146  			addrs = append(addrs, addr)
   147  			break
   148  		}
   149  	}
   150  	if len(addrs) == 0 {
   151  		addrs = info.Addrs
   152  	}
   153  	for _, addr := range addrs {
   154  		err := dialWebsocket(addr, environUUID, opts, pool, try)
   155  		if err == parallel.ErrStopped {
   156  			break
   157  		}
   158  		if err != nil {
   159  			return nil, err
   160  		}
   161  		select {
   162  		case <-time.After(opts.DialAddressInterval):
   163  		case <-try.Dead():
   164  		}
   165  	}
   166  	try.Close()
   167  	result, err := try.Result()
   168  	if err != nil {
   169  		return nil, err
   170  	}
   171  	conn := result.(*websocket.Conn)
   172  	logger.Infof("connection established to %q", conn.RemoteAddr())
   173  
   174  	client := rpc.NewConn(jsoncodec.NewWebsocket(conn), nil)
   175  	client.Start()
   176  	st := &State{
   177  		client:     client,
   178  		conn:       conn,
   179  		addr:       conn.Config().Location.Host,
   180  		serverRoot: "https://" + conn.Config().Location.Host,
   181  		tag:        info.Tag,
   182  		password:   info.Password,
   183  		certPool:   pool,
   184  	}
   185  	if info.Tag != "" || info.Password != "" {
   186  		if err := st.Login(info.Tag, info.Password, info.Nonce); err != nil {
   187  			conn.Close()
   188  			return nil, err
   189  		}
   190  	}
   191  	st.broken = make(chan struct{})
   192  	go st.heartbeatMonitor(PingPeriod)
   193  	return st, nil
   194  }
   195  
   196  func dialWebsocket(addr, environUUID string, opts DialOpts, rootCAs *x509.CertPool, try *parallel.Try) error {
   197  	cfg, err := setUpWebsocket(addr, environUUID, rootCAs)
   198  	if err != nil {
   199  		return err
   200  	}
   201  	return try.Start(newWebsocketDialer(cfg, opts))
   202  }
   203  
   204  func setUpWebsocket(addr, environUUID string, rootCAs *x509.CertPool) (*websocket.Config, error) {
   205  	// origin is required by the WebSocket API, used for "origin policy"
   206  	// in websockets. We pass localhost to satisfy the API; it is
   207  	// inconsequential to us.
   208  	const origin = "http://localhost/"
   209  	tail := "/"
   210  	if environUUID != "" {
   211  		tail = "/environment/" + environUUID + "/api"
   212  	}
   213  	cfg, err := websocket.NewConfig("wss://"+addr+tail, origin)
   214  	if err != nil {
   215  		return nil, err
   216  	}
   217  	cfg.TlsConfig = &tls.Config{
   218  		RootCAs:    rootCAs,
   219  		ServerName: "anything",
   220  	}
   221  	return cfg, nil
   222  }
   223  
   224  // newWebsocketDialer returns a function that
   225  // can be passed to utils/parallel.Try.Start.
   226  func newWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan struct{}) (io.Closer, error) {
   227  	openAttempt := utils.AttemptStrategy{
   228  		Total: opts.Timeout,
   229  		Delay: opts.RetryDelay,
   230  	}
   231  	return func(stop <-chan struct{}) (io.Closer, error) {
   232  		for a := openAttempt.Start(); a.Next(); {
   233  			select {
   234  			case <-stop:
   235  				return nil, parallel.ErrStopped
   236  			default:
   237  			}
   238  			logger.Infof("dialing %q", cfg.Location)
   239  			conn, err := websocket.DialConfig(cfg)
   240  			if err == nil {
   241  				return conn, nil
   242  			}
   243  			if a.HasNext() {
   244  				logger.Debugf("error dialing %q, will retry: %v", cfg.Location, err)
   245  			} else {
   246  				logger.Infof("error dialing %q: %v", cfg.Location, err)
   247  				return nil, fmt.Errorf("unable to connect to %q", cfg.Location)
   248  			}
   249  		}
   250  		panic("unreachable")
   251  	}
   252  }
   253  
   254  func (s *State) heartbeatMonitor(pingPeriod time.Duration) {
   255  	for {
   256  		if err := s.Ping(); err != nil {
   257  			close(s.broken)
   258  			return
   259  		}
   260  		time.Sleep(pingPeriod)
   261  	}
   262  }
   263  
   264  func (s *State) Ping() error {
   265  	return s.Call("Pinger", "", "Ping", nil, nil)
   266  }
   267  
   268  // Call invokes a low-level RPC method of the given objType, id, and
   269  // request, passing the given parameters and filling in the response
   270  // results. This should not be used directly by clients.
   271  // TODO (dimitern) Add tests for all client-facing objects to verify
   272  // we return the correct error when invoking Call("Object",
   273  // "non-empty-id",...)
   274  func (s *State) Call(objType, id, request string, args, response interface{}) error {
   275  	err := s.client.Call(rpc.Request{
   276  		Type:   objType,
   277  		Id:     id,
   278  		Action: request,
   279  	}, args, response)
   280  	return params.ClientError(err)
   281  }
   282  
   283  func (s *State) Close() error {
   284  	return s.client.Close()
   285  }
   286  
   287  // Broken returns a channel that's closed when the connection is broken.
   288  func (s *State) Broken() <-chan struct{} {
   289  	return s.broken
   290  }
   291  
   292  // RPCClient returns the RPC client for the state, so that testing
   293  // functions can tickle parts of the API that the conventional entry
   294  // points don't reach. This is exported for testing purposes only.
   295  func (s *State) RPCClient() *rpc.Conn {
   296  	return s.client
   297  }
   298  
   299  // Addr returns the address used to connect to the API server.
   300  func (s *State) Addr() string {
   301  	return s.addr
   302  }
   303  
   304  // EnvironTag returns the Environment Tag describing the environment we are
   305  // connected to.
   306  func (s *State) EnvironTag() string {
   307  	return s.environTag
   308  }
   309  
   310  // APIHostPorts returns addresses that may be used to connect
   311  // to the API server, including the address used to connect.
   312  //
   313  // The addresses are scoped (public, cloud-internal, etc.), so
   314  // the client may choose which addresses to attempt. For the
   315  // Juju CLI, all addresses must be attempted, as the CLI may
   316  // be invoked both within and outside the environment (think
   317  // private clouds).
   318  func (s *State) APIHostPorts() [][]instance.HostPort {
   319  	hostPorts := make([][]instance.HostPort, len(s.hostPorts))
   320  	for i, server := range s.hostPorts {
   321  		hostPorts[i] = append([]instance.HostPort{}, server...)
   322  	}
   323  	return hostPorts
   324  }