github.com/niedbalski/juju@v0.0.0-20190215020005-8ff100488e47/worker/raft/rafttransport/worker.go (about) 1 // Copyright 2018 Canonical Ltd. 2 // Licensed under the AGPLv3, see LICENCE file for details. 3 4 package rafttransport 5 6 import ( 7 "context" 8 "crypto/tls" 9 "log" 10 "net" 11 "net/http" 12 "time" 13 14 "github.com/hashicorp/raft" 15 "github.com/juju/clock" 16 "github.com/juju/errors" 17 "github.com/juju/loggo" 18 "github.com/juju/pubsub" 19 "github.com/juju/replicaset" 20 "gopkg.in/juju/worker.v1" 21 "gopkg.in/juju/worker.v1/catacomb" 22 23 "github.com/juju/juju/api" 24 "github.com/juju/juju/apiserver/apiserverhttp" 25 "github.com/juju/juju/apiserver/httpcontext" 26 "github.com/juju/juju/worker/raft/raftutil" 27 ) 28 29 var ( 30 logger = loggo.GetLogger("juju.worker.raft.rafttransport") 31 ) 32 33 const ( 34 maxPoolSize = replicaset.MaxPeers 35 ) 36 37 // Config is the configuration required for running an apiserver-based 38 // raft transport worker. 39 type Config struct { 40 // APIInfo contains the information, excluding addresses, 41 // required to connect to an API server. 42 APIInfo *api.Info 43 44 // Authenticator is the HTTP request authenticator to use for 45 // the raft endpoint. 46 Authenticator httpcontext.Authenticator 47 48 // DialConn is the function to use for dialing connections to 49 // other API servers. 50 DialConn DialConnFunc 51 52 // Hub is the central hub to which the worker will subscribe 53 // for notification of local address changes. 54 Hub *pubsub.StructuredHub 55 56 // Mux is the API server HTTP mux into which the handler will 57 // be installed. 58 Mux *apiserverhttp.Mux 59 60 // Path is the path of the raft HTTP endpoint. 61 Path string 62 63 // LocalID is the raft.ServerID of the agent running this worker. 64 LocalID raft.ServerID 65 66 // Timeout, if non-zero, is the timeout to apply to transport 67 // operations. See raft.NetworkTransportConfig.Timeout for more 68 // details. 69 Timeout time.Duration 70 71 // TLSConfig is the TLS configuration to use for making 72 // connections to API servers. 73 TLSConfig *tls.Config 74 75 // Clock is used for timing out the Addr getter - if the 76 // peergrouper isn't publishing good API addresses in a timely 77 // fashion it's better to fail and log than to hang indefinitely. 78 Clock clock.Clock 79 } 80 81 // DialConnFunc is type of function used by the transport for 82 // dialing a TLS connection to another API server. The worker 83 // will send an HTTP request over the connection to upgrade it. 84 type DialConnFunc func(ctx context.Context, addr string, tlsConfig *tls.Config) (net.Conn, error) 85 86 // Validate validates the raft worker configuration. 87 func (config Config) Validate() error { 88 if config.APIInfo == nil { 89 return errors.NotValidf("nil APIInfo") 90 } 91 if config.Authenticator == nil { 92 return errors.NotValidf("nil Authenticator") 93 } 94 if config.DialConn == nil { 95 return errors.NotValidf("nil DialConn") 96 } 97 if config.Hub == nil { 98 return errors.NotValidf("nil Hub") 99 } 100 if config.Mux == nil { 101 return errors.NotValidf("nil Mux") 102 } 103 if config.Path == "" { 104 return errors.NotValidf("empty Path") 105 } 106 if config.LocalID == "" { 107 return errors.NotValidf("empty LocalID") 108 } 109 if config.TLSConfig == nil { 110 return errors.NotValidf("nil TLSConfig") 111 } 112 return nil 113 } 114 115 // NewWorker returns a new apiserver-based raft transport worker, 116 // with the given configuration. The worker itself implements 117 // raft.Transport. 118 func NewWorker(config Config) (worker.Worker, error) { 119 if err := config.Validate(); err != nil { 120 return nil, errors.Trace(err) 121 } 122 123 apiPorts := config.APIInfo.Ports() 124 if n := len(apiPorts); n != 1 { 125 return nil, errors.Errorf("api.Info has %d unique ports, expected 1", n) 126 } 127 128 w := &Worker{ 129 config: config, 130 connections: make(chan net.Conn), 131 dialRequests: make(chan dialRequest), 132 apiPort: apiPorts[0], 133 } 134 135 const logPrefix = "[transport] " 136 logWriter := &raftutil.LoggoWriter{logger, loggo.DEBUG} 137 logLogger := log.New(logWriter, logPrefix, 0) 138 stream, err := newStreamLayer(config.LocalID, config.Hub, w.connections, config.Clock, &Dialer{ 139 APIInfo: config.APIInfo, 140 DialRaw: w.dialRaw, 141 Path: config.Path, 142 }) 143 if err != nil { 144 return nil, errors.Trace(err) 145 } 146 transport := raft.NewNetworkTransportWithConfig(&raft.NetworkTransportConfig{ 147 Logger: logLogger, 148 MaxPool: maxPoolSize, 149 Stream: stream, 150 Timeout: config.Timeout, 151 }) 152 w.Transport = transport 153 154 var h http.Handler = NewHandler(w.connections, w.catacomb.Dying()) 155 h = &httpcontext.BasicAuthHandler{ 156 Handler: h, 157 Authenticator: w.config.Authenticator, 158 Authorizer: httpcontext.AuthorizerFunc(controllerAuthorizer), 159 } 160 h = &httpcontext.ImpliedModelHandler{ 161 Handler: h, 162 ModelUUID: w.config.APIInfo.ModelTag.Id(), 163 } 164 165 w.config.Mux.AddHandler("GET", w.config.Path, h) 166 167 if err := catacomb.Invoke(catacomb.Plan{ 168 Site: &w.catacomb, 169 Work: func() error { 170 defer transport.Close() 171 defer w.config.Mux.RemoveHandler("GET", w.config.Path) 172 return w.loop() 173 }, 174 Init: []worker.Worker{stream}, 175 }); err != nil { 176 transport.Close() 177 w.config.Mux.RemoveHandler("GET", w.config.Path) 178 return nil, errors.Trace(err) 179 } 180 return w, nil 181 } 182 183 // Worker is a worker that manages a raft.Transport. 184 type Worker struct { 185 raft.Transport 186 187 catacomb catacomb.Catacomb 188 config Config 189 connections chan net.Conn 190 dialRequests chan dialRequest 191 tlsConfig *tls.Config 192 apiPort int 193 } 194 195 type dialRequest struct { 196 ctx context.Context 197 address string 198 result chan<- dialResult 199 } 200 201 type dialResult struct { 202 conn net.Conn 203 err error 204 } 205 206 // Kill is part of the worker.Worker interface. 207 func (w *Worker) Kill() { 208 w.catacomb.Kill(nil) 209 } 210 211 // Wait is part of the worker.Worker interface. 212 func (w *Worker) Wait() error { 213 return w.catacomb.Wait() 214 } 215 216 // dialRaw dials a new TLS connection to the controller identified 217 // by the given address. The address is expected to be the stringified 218 // tag of a controller machine agent. The resulting connection is 219 // appropriate for use as Dialer.DialRaw. 220 func (w *Worker) dialRaw(address raft.ServerAddress, timeout time.Duration) (net.Conn, error) { 221 // Give precedence to the worker dying. 222 select { 223 case <-w.catacomb.Dying(): 224 return nil, w.errDialWorkerStopped() 225 default: 226 } 227 228 ctx := context.Background() 229 if timeout != 0 { 230 var cancel context.CancelFunc 231 ctx, cancel = context.WithTimeout(ctx, timeout) 232 defer cancel() 233 } 234 235 resultCh := make(chan dialResult) 236 req := dialRequest{ 237 ctx: ctx, 238 address: string(address), 239 result: resultCh, 240 } 241 select { 242 case <-w.catacomb.Dying(): 243 return nil, w.errDialWorkerStopped() 244 case <-ctx.Done(): 245 return nil, dialRequestTimeoutError{} 246 case w.dialRequests <- req: 247 } 248 249 select { 250 case res := <-resultCh: 251 return res.conn, res.err 252 case <-ctx.Done(): 253 return nil, dialRequestTimeoutError{} 254 case <-w.catacomb.Dying(): 255 return nil, w.errDialWorkerStopped() 256 } 257 } 258 259 func (w *Worker) errDialWorkerStopped() error { 260 err := w.catacomb.Err() 261 if err != nil && err != w.catacomb.ErrDying() { 262 return dialWorkerStoppedError{err} 263 } 264 return dialWorkerStoppedError{ 265 errors.New("worker stopped"), 266 } 267 } 268 269 func (w *Worker) loop() error { 270 for { 271 select { 272 case <-w.catacomb.Dying(): 273 return w.catacomb.ErrDying() 274 case req := <-w.dialRequests: 275 go w.handleDial(req) 276 } 277 } 278 } 279 280 func (w *Worker) handleDial(req dialRequest) { 281 conn, err := w.config.DialConn(req.ctx, req.address, w.config.TLSConfig) 282 select { 283 case req.result <- dialResult{conn, err}: 284 return 285 case <-req.ctx.Done(): 286 case <-w.catacomb.Dying(): 287 } 288 if err == nil { 289 // result wasn't delivered, close connection 290 conn.Close() 291 } 292 } 293 294 // DialConn dials a TLS connection to the API server with the 295 // given address, using the given TLS configuration. This will 296 // be used for requesting the raft endpoint, upgrading to a 297 // raw connection for inter-node raft communications. 298 // 299 // TODO: this function needs to be made proxy-aware. 300 func DialConn(ctx context.Context, addr string, tlsConfig *tls.Config) (net.Conn, error) { 301 dialer := &net.Dialer{} 302 if deadline, ok := ctx.Deadline(); ok { 303 dialer.Deadline = deadline 304 } 305 306 ctx, cancel := context.WithCancel(ctx) 307 defer cancel() 308 309 canceled := make(chan struct{}) 310 go func() { 311 <-ctx.Done() 312 if ctx.Err() == context.Canceled { 313 close(canceled) 314 } 315 }() 316 dialer.Cancel = canceled 317 318 return tls.DialWithDialer(dialer, "tcp", addr, tlsConfig) 319 } 320 321 func controllerAuthorizer(authInfo httpcontext.AuthInfo) error { 322 if authInfo.Controller { 323 return nil 324 } 325 return errors.New("controller agents only") 326 }