github.com/mattyw/juju@v0.0.0-20140610034352-732aecd63861/state/api/apiclient.go (about) 1 // Copyright 2012, 2013 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package api 5 6 import ( 7 "crypto/tls" 8 "crypto/x509" 9 "fmt" 10 "io" 11 "strings" 12 "time" 13 14 "code.google.com/p/go.net/websocket" 15 "github.com/juju/loggo" 16 "github.com/juju/names" 17 "github.com/juju/utils" 18 "github.com/juju/utils/parallel" 19 20 "github.com/juju/juju/cert" 21 "github.com/juju/juju/instance" 22 "github.com/juju/juju/rpc" 23 "github.com/juju/juju/rpc/jsoncodec" 24 "github.com/juju/juju/state/api/params" 25 ) 26 27 var logger = loggo.GetLogger("juju.state.api") 28 29 // PingPeriod defines how often the internal connection health check 30 // will run. It's a variable so it can be changed in tests. 31 var PingPeriod = 1 * time.Minute 32 33 type State struct { 34 client *rpc.Conn 35 conn *websocket.Conn 36 37 // addr is the address used to connect to the API server. 38 addr string 39 40 // environTag holds the environment tag once we're connected 41 environTag string 42 43 // hostPorts is the API server addresses returned from Login, 44 // which the client may cache and use for failover. 45 hostPorts [][]instance.HostPort 46 47 // authTag holds the authenticated entity's tag after login. 48 authTag string 49 50 // broken is a channel that gets closed when the connection is 51 // broken. 52 broken chan struct{} 53 54 // tag and password hold the cached login credentials. 55 tag string 56 password string 57 58 // serverRoot holds the cached API server address and port we used 59 // to login, with a https:// prefix. 60 serverRoot string 61 62 // certPool holds the cert pool that is used to authenticate the tls 63 // connections to the API. 64 certPool *x509.CertPool 65 } 66 67 // Info encapsulates information about a server holding juju state and 68 // can be used to make a connection to it. 
type Info struct {
	// Addrs holds the addresses of the state servers.
	Addrs []string

	// CACert holds the CA certificate that will be used
	// to validate the state server's certificate, in PEM format.
	CACert string

	// Tag holds the name of the entity that is connecting.
	// If this and the password are empty, no login attempt will be made
	// (this is to allow tests to access the API to check that operations
	// fail when not logged in).
	Tag string

	// Password holds the password for the administrator or connecting entity.
	Password string

	// Nonce holds the nonce used when provisioning the machine. Used
	// only by the machine agent.
	Nonce string `yaml:",omitempty"`

	// EnvironTag holds the tag of the environment we are trying to
	// connect to. When empty, the server's default API endpoint is used.
	EnvironTag string
}

// DialOpts holds configuration parameters that control the
// dialing behavior when connecting to a state server.
type DialOpts struct {
	// DialAddressInterval is the amount of time to wait
	// before starting to dial another address.
	DialAddressInterval time.Duration

	// Timeout is the amount of time to wait contacting
	// a state server.
	Timeout time.Duration

	// RetryDelay is the amount of time to wait between
	// unsuccessful connection attempts.
	RetryDelay time.Duration
}

// DefaultDialOpts returns a DialOpts representing the default
// parameters for contacting a state server.
113 func DefaultDialOpts() DialOpts { 114 return DialOpts{ 115 DialAddressInterval: 50 * time.Millisecond, 116 Timeout: 10 * time.Minute, 117 RetryDelay: 2 * time.Second, 118 } 119 } 120 121 func Open(info *Info, opts DialOpts) (*State, error) { 122 if len(info.Addrs) == 0 { 123 return nil, fmt.Errorf("no API addresses to connect to") 124 } 125 pool := x509.NewCertPool() 126 xcert, err := cert.ParseCert(info.CACert) 127 if err != nil { 128 return nil, err 129 } 130 pool.AddCert(xcert) 131 132 environUUID := "" 133 if info.EnvironTag != "" { 134 _, envUUID, err := names.ParseTag(info.EnvironTag, names.EnvironTagKind) 135 if err != nil { 136 return nil, err 137 } 138 environUUID = envUUID 139 } 140 // Dial all addresses at reasonable intervals. 141 try := parallel.NewTry(0, nil) 142 defer try.Kill() 143 var addrs []string 144 for _, addr := range info.Addrs { 145 if strings.HasPrefix(addr, "localhost:") { 146 addrs = append(addrs, addr) 147 break 148 } 149 } 150 if len(addrs) == 0 { 151 addrs = info.Addrs 152 } 153 for _, addr := range addrs { 154 err := dialWebsocket(addr, environUUID, opts, pool, try) 155 if err == parallel.ErrStopped { 156 break 157 } 158 if err != nil { 159 return nil, err 160 } 161 select { 162 case <-time.After(opts.DialAddressInterval): 163 case <-try.Dead(): 164 } 165 } 166 try.Close() 167 result, err := try.Result() 168 if err != nil { 169 return nil, err 170 } 171 conn := result.(*websocket.Conn) 172 logger.Infof("connection established to %q", conn.RemoteAddr()) 173 174 client := rpc.NewConn(jsoncodec.NewWebsocket(conn), nil) 175 client.Start() 176 st := &State{ 177 client: client, 178 conn: conn, 179 addr: conn.Config().Location.Host, 180 serverRoot: "https://" + conn.Config().Location.Host, 181 tag: info.Tag, 182 password: info.Password, 183 certPool: pool, 184 } 185 if info.Tag != "" || info.Password != "" { 186 if err := st.Login(info.Tag, info.Password, info.Nonce); err != nil { 187 conn.Close() 188 return nil, err 189 } 190 } 191 
st.broken = make(chan struct{}) 192 go st.heartbeatMonitor(PingPeriod) 193 return st, nil 194 } 195 196 func dialWebsocket(addr, environUUID string, opts DialOpts, rootCAs *x509.CertPool, try *parallel.Try) error { 197 cfg, err := setUpWebsocket(addr, environUUID, rootCAs) 198 if err != nil { 199 return err 200 } 201 return try.Start(newWebsocketDialer(cfg, opts)) 202 } 203 204 func setUpWebsocket(addr, environUUID string, rootCAs *x509.CertPool) (*websocket.Config, error) { 205 // origin is required by the WebSocket API, used for "origin policy" 206 // in websockets. We pass localhost to satisfy the API; it is 207 // inconsequential to us. 208 const origin = "http://localhost/" 209 tail := "/" 210 if environUUID != "" { 211 tail = "/environment/" + environUUID + "/api" 212 } 213 cfg, err := websocket.NewConfig("wss://"+addr+tail, origin) 214 if err != nil { 215 return nil, err 216 } 217 cfg.TlsConfig = &tls.Config{ 218 RootCAs: rootCAs, 219 ServerName: "anything", 220 } 221 return cfg, nil 222 } 223 224 // newWebsocketDialer returns a function that 225 // can be passed to utils/parallel.Try.Start. 
226 func newWebsocketDialer(cfg *websocket.Config, opts DialOpts) func(<-chan struct{}) (io.Closer, error) { 227 openAttempt := utils.AttemptStrategy{ 228 Total: opts.Timeout, 229 Delay: opts.RetryDelay, 230 } 231 return func(stop <-chan struct{}) (io.Closer, error) { 232 for a := openAttempt.Start(); a.Next(); { 233 select { 234 case <-stop: 235 return nil, parallel.ErrStopped 236 default: 237 } 238 logger.Infof("dialing %q", cfg.Location) 239 conn, err := websocket.DialConfig(cfg) 240 if err == nil { 241 return conn, nil 242 } 243 if a.HasNext() { 244 logger.Debugf("error dialing %q, will retry: %v", cfg.Location, err) 245 } else { 246 logger.Infof("error dialing %q: %v", cfg.Location, err) 247 return nil, fmt.Errorf("unable to connect to %q", cfg.Location) 248 } 249 } 250 panic("unreachable") 251 } 252 } 253 254 func (s *State) heartbeatMonitor(pingPeriod time.Duration) { 255 for { 256 if err := s.Ping(); err != nil { 257 close(s.broken) 258 return 259 } 260 time.Sleep(pingPeriod) 261 } 262 } 263 264 func (s *State) Ping() error { 265 return s.Call("Pinger", "", "Ping", nil, nil) 266 } 267 268 // Call invokes a low-level RPC method of the given objType, id, and 269 // request, passing the given parameters and filling in the response 270 // results. This should not be used directly by clients. 271 // TODO (dimitern) Add tests for all client-facing objects to verify 272 // we return the correct error when invoking Call("Object", 273 // "non-empty-id",...) 274 func (s *State) Call(objType, id, request string, args, response interface{}) error { 275 err := s.client.Call(rpc.Request{ 276 Type: objType, 277 Id: id, 278 Action: request, 279 }, args, response) 280 return params.ClientError(err) 281 } 282 283 func (s *State) Close() error { 284 return s.client.Close() 285 } 286 287 // Broken returns a channel that's closed when the connection is broken. 
288 func (s *State) Broken() <-chan struct{} { 289 return s.broken 290 } 291 292 // RPCClient returns the RPC client for the state, so that testing 293 // functions can tickle parts of the API that the conventional entry 294 // points don't reach. This is exported for testing purposes only. 295 func (s *State) RPCClient() *rpc.Conn { 296 return s.client 297 } 298 299 // Addr returns the address used to connect to the API server. 300 func (s *State) Addr() string { 301 return s.addr 302 } 303 304 // EnvironTag returns the Environment Tag describing the environment we are 305 // connected to. 306 func (s *State) EnvironTag() string { 307 return s.environTag 308 } 309 310 // APIHostPorts returns addresses that may be used to connect 311 // to the API server, including the address used to connect. 312 // 313 // The addresses are scoped (public, cloud-internal, etc.), so 314 // the client may choose which addresses to attempt. For the 315 // Juju CLI, all addresses must be attempted, as the CLI may 316 // be invoked both within and outside the environment (think 317 // private clouds). 318 func (s *State) APIHostPorts() [][]instance.HostPort { 319 hostPorts := make([][]instance.HostPort, len(s.hostPorts)) 320 for i, server := range s.hostPorts { 321 hostPorts[i] = append([]instance.HostPort{}, server...) 322 } 323 return hostPorts 324 }