github.com/djenriquez/nomad-1@v0.8.1/nomad/rpc.go

package nomad

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/rpc"
	"strings"
	"time"

	metrics "github.com/armon/go-metrics"
	"github.com/hashicorp/consul/lib"
	memdb "github.com/hashicorp/go-memdb"
	"github.com/hashicorp/nomad/helper/pool"
	"github.com/hashicorp/nomad/nomad/state"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/raft"
	"github.com/hashicorp/yamux"
	"github.com/ugorji/go/codec"
)

const (
	// maxQueryTime is used to bound the limit of a blocking query
	maxQueryTime = 300 * time.Second

	// defaultQueryTime is the amount of time we block waiting for a change
	// if no time is specified. Previously we would wait the maxQueryTime.
	defaultQueryTime = 300 * time.Second

	// Warn if the Raft command is larger than this.
	// If it's over 1MB something is probably being abusive.
	raftWarnSize = 1024 * 1024

	// enqueueLimit caps how long we will wait to enqueue
	// a new Raft command. Something is probably wrong if this
	// value is ever reached. However, it prevents us from blocking
	// the requesting goroutine forever.
	enqueueLimit = 30 * time.Second
)

// RPCContext provides metadata about the RPC connection.
type RPCContext struct {
	// Conn exposes the raw connection.
	Conn net.Conn

	// Session exposes the multiplexed connection session.
	Session *yamux.Session

	// TLS marks whether the RPC is over a TLS based connection
	TLS bool

	// VerifiedChains is the verified certificate chains presented by the
	// incoming connection.
	VerifiedChains [][]*x509.Certificate

	// NodeID marks the NodeID that initiated the connection.
	NodeID string
}
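// exampleDialNomadRPC is an illustrative sketch, not part of the original
// file: it shows the client side of the one-byte preamble protocol handled
// by listen/handleConn below. The first byte written on a fresh connection
// selects the wire mode; the address used here is hypothetical.
func exampleDialNomadRPC() (net.Conn, error) {
	conn, err := net.DialTimeout("tcp", "127.0.0.1:4647", 10*time.Second)
	if err != nil {
		return nil, err
	}

	// Select plain Nomad RPC mode; the server's handleConn switches on this
	// byte and services the connection with handleNomadConn.
	if _, err := conn.Write([]byte{byte(pool.RpcNomad)}); err != nil {
		conn.Close()
		return nil, err
	}
	return conn, nil
}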
// listen is used to listen for incoming RPC connections
func (s *Server) listen(ctx context.Context) {
	defer close(s.listenerCh)
	for {
		select {
		case <-ctx.Done():
			s.logger.Println("[INFO] nomad.rpc: Closing server RPC connection")
			return
		default:
		}

		// Accept a connection
		conn, err := s.rpcListener.Accept()
		if err != nil {
			if s.shutdown {
				return
			}

			select {
			case <-ctx.Done():
				return
			default:
			}

			s.logger.Printf("[ERR] nomad.rpc: failed to accept RPC conn: %v", err)
			continue
		}

		go s.handleConn(ctx, conn, &RPCContext{Conn: conn})
		metrics.IncrCounter([]string{"nomad", "rpc", "accept_conn"}, 1)
	}
}

// handleConn is used to determine if this is a Raft or
// Nomad type RPC connection and invoke the correct handler
func (s *Server) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
	// Read a single byte
	buf := make([]byte, 1)
	if _, err := conn.Read(buf); err != nil {
		if err != io.EOF {
			s.logger.Printf("[ERR] nomad.rpc: failed to read byte: %v", err)
		}
		conn.Close()
		return
	}

	// Enforce TLS if EnableRPC is set
	if s.config.TLSConfig.EnableRPC && !rpcCtx.TLS && pool.RPCType(buf[0]) != pool.RpcTLS {
		if !s.config.TLSConfig.RPCUpgradeMode {
			s.logger.Printf("[WARN] nomad.rpc: Non-TLS connection attempted from %s with RequireTLS set", conn.RemoteAddr().String())
			conn.Close()
			return
		}
	}

	// Switch on the byte
	switch pool.RPCType(buf[0]) {
	case pool.RpcNomad:
		// Create an RPC Server and handle the request
		server := rpc.NewServer()
		s.setupRpcServer(server, rpcCtx)
		s.handleNomadConn(ctx, conn, server)

		// Remove any potential mapping between a NodeID and this connection
		// and close the underlying connection.
		s.removeNodeConn(rpcCtx)

	case pool.RpcRaft:
		metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1)
		s.raftLayer.Handoff(ctx, conn)

	case pool.RpcMultiplex:
		s.handleMultiplex(ctx, conn, rpcCtx)

	case pool.RpcTLS:
		if s.rpcTLS == nil {
			s.logger.Printf("[WARN] nomad.rpc: TLS connection attempted, server not configured for TLS")
			conn.Close()
			return
		}
		conn = tls.Server(conn, s.rpcTLS)

		// Force a handshake so we can get information about the TLS
		// connection state.
		tlsConn, ok := conn.(*tls.Conn)
		if !ok {
			s.logger.Printf("[ERR] nomad.rpc: expected TLS connection but got %T", conn)
			conn.Close()
			return
		}

		if err := tlsConn.Handshake(); err != nil {
			s.logger.Printf("[WARN] nomad.rpc: failed TLS handshake from %v: %v", tlsConn.RemoteAddr(), err)
			conn.Close()
			return
		}

		// Update the connection context with the fact that the connection is
		// using TLS
		rpcCtx.TLS = true

		// Store the verified chains so they can be inspected later.
		state := tlsConn.ConnectionState()
		rpcCtx.VerifiedChains = state.VerifiedChains

		s.handleConn(ctx, conn, rpcCtx)

	case pool.RpcStreaming:
		s.handleStreamingConn(conn)

	case pool.RpcMultiplexV2:
		s.handleMultiplexV2(ctx, conn, rpcCtx)

	default:
		s.logger.Printf("[ERR] nomad.rpc: unrecognized RPC byte: %v", buf[0])
		conn.Close()
		return
	}
}
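// Note on the RpcTLS case in handleConn above: after the TLS handshake
// succeeds, handleConn recurses on the wrapped connection, so a TLS client
// sends a second preamble byte (e.g. RpcNomad or RpcMultiplexV2) inside the
// encrypted stream to select the actual wire mode.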
// handleMultiplex is used to multiplex a single incoming connection
// using the Yamux multiplexer
func (s *Server) handleMultiplex(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
	defer func() {
		// Remove any potential mapping between a NodeID and this connection
		// and close the underlying connection.
		s.removeNodeConn(rpcCtx)
		conn.Close()
	}()

	conf := yamux.DefaultConfig()
	conf.LogOutput = s.config.LogOutput
	server, err := yamux.Server(conn, conf)
	if err != nil {
		s.logger.Printf("[ERR] nomad.rpc: multiplex failed to create yamux server: %v", err)
		return
	}

	// Update the context to store the yamux session
	rpcCtx.Session = server

	// Create the RPC server for this connection
	rpcServer := rpc.NewServer()
	s.setupRpcServer(rpcServer, rpcCtx)

	for {
		sub, err := server.Accept()
		if err != nil {
			if err != io.EOF {
				s.logger.Printf("[ERR] nomad.rpc: multiplex conn accept failed: %v", err)
			}
			return
		}
		go s.handleNomadConn(ctx, sub, rpcServer)
	}
}

// handleNomadConn is used to service a single Nomad RPC connection
func (s *Server) handleNomadConn(ctx context.Context, conn net.Conn, server *rpc.Server) {
	defer conn.Close()
	rpcCodec := pool.NewServerCodec(conn)
	for {
		select {
		case <-ctx.Done():
			s.logger.Println("[INFO] nomad.rpc: Closing server RPC connection")
			return
		case <-s.shutdownCh:
			return
		default:
		}

		if err := server.ServeRequest(rpcCodec); err != nil {
			if err != io.EOF && !strings.Contains(err.Error(), "closed") {
				s.logger.Printf("[ERR] nomad.rpc: RPC error: %v (%v)", err, conn)
				metrics.IncrCounter([]string{"nomad", "rpc", "request_error"}, 1)
			}
			return
		}
		metrics.IncrCounter([]string{"nomad", "rpc", "request"}, 1)
	}
}
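// exampleDialMultiplex is an illustrative sketch, not part of the original
// file: it shows the client side of handleMultiplex above. After the
// RpcMultiplex preamble byte, the client runs yamux over the same connection
// and opens one stream per concurrent RPC; the server services each accepted
// stream with handleNomadConn. The address used here is hypothetical.
func exampleDialMultiplex() (*yamux.Session, error) {
	conn, err := net.DialTimeout("tcp", "127.0.0.1:4647", 10*time.Second)
	if err != nil {
		return nil, err
	}

	if _, err := conn.Write([]byte{byte(pool.RpcMultiplex)}); err != nil {
		conn.Close()
		return nil, err
	}

	session, err := yamux.Client(conn, yamux.DefaultConfig())
	if err != nil {
		conn.Close()
		return nil, err
	}
	return session, nil
}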
// handleStreamingConn is used to handle a single Streaming Nomad RPC connection.
func (s *Server) handleStreamingConn(conn net.Conn) {
	defer conn.Close()

	// Decode the header
	var header structs.StreamingRpcHeader
	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
	if err := decoder.Decode(&header); err != nil {
		if err != io.EOF && !strings.Contains(err.Error(), "closed") {
			s.logger.Printf("[ERR] nomad.rpc: Streaming RPC error: %v (%v)", err, conn)
			metrics.IncrCounter([]string{"nomad", "streaming_rpc", "request_error"}, 1)
		}

		return
	}

	ack := structs.StreamingRpcAck{}
	handler, err := s.streamingRpcs.GetHandler(header.Method)
	if err != nil {
		s.logger.Printf("[ERR] nomad.rpc: Streaming RPC error: %v (%v)", err, conn)
		metrics.IncrCounter([]string{"nomad", "streaming_rpc", "request_error"}, 1)
		ack.Error = err.Error()
	}

	// Send the acknowledgement
	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
	if err := encoder.Encode(ack); err != nil {
		conn.Close()
		return
	}

	if ack.Error != "" {
		return
	}

	// Invoke the handler
	metrics.IncrCounter([]string{"nomad", "streaming_rpc", "request"}, 1)
	handler(conn)
}

// handleMultiplexV2 is used to multiplex a single incoming connection
// using the Yamux multiplexer. Version 2 handling allows a single connection
// to switch streams between regular RPCs and Streaming RPCs.
func (s *Server) handleMultiplexV2(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
	defer func() {
		// Remove any potential mapping between a NodeID and this connection
		// and close the underlying connection.
		s.removeNodeConn(rpcCtx)
		conn.Close()
	}()

	conf := yamux.DefaultConfig()
	conf.LogOutput = s.config.LogOutput
	server, err := yamux.Server(conn, conf)
	if err != nil {
		s.logger.Printf("[ERR] nomad.rpc: multiplex_v2 failed to create yamux server: %v", err)
		return
	}

	// Update the context to store the yamux session
	rpcCtx.Session = server

	// Create the RPC server for this connection
	rpcServer := rpc.NewServer()
	s.setupRpcServer(rpcServer, rpcCtx)

	for {
		// Accept a new stream
		sub, err := server.Accept()
		if err != nil {
			if err != io.EOF {
				s.logger.Printf("[ERR] nomad.rpc: multiplex_v2 conn accept failed: %v", err)
			}
			return
		}

		// Read a single byte
		buf := make([]byte, 1)
		if _, err := sub.Read(buf); err != nil {
			if err != io.EOF {
				s.logger.Printf("[ERR] nomad.rpc: multiplex_v2 failed to read byte: %v", err)
			}
			return
		}

		// Determine which handler to use
		switch pool.RPCType(buf[0]) {
		case pool.RpcNomad:
			go s.handleNomadConn(ctx, sub, rpcServer)
		case pool.RpcStreaming:
			go s.handleStreamingConn(sub)

		default:
			s.logger.Printf("[ERR] nomad.rpc: multiplex_v2 unrecognized RPC byte: %v", buf[0])
			return
		}
	}
}
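// Unlike handleMultiplex, handleMultiplexV2 reads one extra type byte at the
// start of every yamux stream, which lets a single multiplexed connection mix
// regular Nomad RPC streams with streaming RPC streams. The client half of
// the handshake consumed by handleStreamingConn is streamingRpcImpl below.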
// forward is used to forward an RPC call to a remote region or to the local
// leader. Returns a bool indicating whether forwarding was performed, as well
// as any error.
func (s *Server) forward(method string, info structs.RPCInfo, args interface{}, reply interface{}) (bool, error) {
	var firstCheck time.Time

	region := info.RequestRegion()
	if region == "" {
		return true, fmt.Errorf("missing target RPC")
	}

	// Handle region forwarding
	if region != s.config.Region {
		// Mark that we are forwarding the RPC
		info.SetForwarded()
		err := s.forwardRegion(region, method, args, reply)
		return true, err
	}

	// Check if we can allow a stale read
	if info.IsRead() && info.AllowStaleRead() {
		return false, nil
	}

CHECK_LEADER:
	// Find the leader
	isLeader, remoteServer := s.getLeader()

	// Handle the case we are the leader
	if isLeader {
		return false, nil
	}

	// Handle the case of a known leader
	if remoteServer != nil {
		// Mark that we are forwarding the RPC
		info.SetForwarded()
		err := s.forwardLeader(remoteServer, method, args, reply)
		return true, err
	}

	// Gate the request until there is a leader
	if firstCheck.IsZero() {
		firstCheck = time.Now()
	}
	if time.Since(firstCheck) < s.config.RPCHoldTimeout {
		jitter := lib.RandomStagger(s.config.RPCHoldTimeout / structs.JitterFraction)
		select {
		case <-time.After(jitter):
			goto CHECK_LEADER
		case <-s.shutdownCh:
		}
	}

	// No leader found and hold time exceeded
	return true, structs.ErrNoLeader
}

// getLeader returns whether the current node is the leader and, if not, the
// leader server, which may be nil if the cluster has not yet elected one.
func (s *Server) getLeader() (bool, *serverParts) {
	// Check if we are the leader
	if s.IsLeader() {
		return true, nil
	}

	// Get the leader
	leader := s.raft.Leader()
	if leader == "" {
		return false, nil
	}

	// Lookup the server
	s.peerLock.RLock()
	server := s.localPeers[leader]
	s.peerLock.RUnlock()

	// Server could be nil
	return false, server
}

// forwardLeader is used to forward an RPC call to the leader, or fail if no leader
func (s *Server) forwardLeader(server *serverParts, method string, args interface{}, reply interface{}) error {
	// Handle a missing server
	if server == nil {
		return structs.ErrNoLeader
	}
	return s.connPool.RPC(s.config.Region, server.Addr, server.MajorVersion, method, args, reply)
}

// forwardServer is used to forward an RPC call to a particular server
func (s *Server) forwardServer(server *serverParts, method string, args interface{}, reply interface{}) error {
	// Handle a missing server
	if server == nil {
		return errors.New("must be given a valid server address")
	}
	return s.connPool.RPC(s.config.Region, server.Addr, server.MajorVersion, method, args, reply)
}

// forwardRegion is used to forward an RPC call to a remote region, or fail if no servers
func (s *Server) forwardRegion(region, method string, args interface{}, reply interface{}) error {
	// Bail if we can't find any servers
	s.peerLock.RLock()
	servers := s.peers[region]
	if len(servers) == 0 {
		s.peerLock.RUnlock()
		s.logger.Printf("[WARN] nomad.rpc: RPC request for region '%s', no path found",
			region)
		return structs.ErrNoRegionPath
	}

	// Select a random addr
	offset := rand.Intn(len(servers))
	server := servers[offset]
	s.peerLock.RUnlock()

	// Forward to remote Nomad
	metrics.IncrCounter([]string{"nomad", "rpc", "cross-region", region}, 1)
	return s.connPool.RPC(region, server.Addr, server.MajorVersion, method, args, reply)
}
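// As an illustrative sketch, not part of the original file, endpoint handlers
// are expected to call forward before doing any local work and return early
// when the request was served remotely. The args value typically doubles as
// the structs.RPCInfo parameter, since the embedded QueryOptions/WriteRequest
// implement that interface; the endpoint and request types here are
// hypothetical:
//
//	func (e *Example) GetExample(args *structs.ExampleRequest, reply *structs.ExampleResponse) error {
//		if done, err := e.srv.forward("Example.GetExample", args, args, reply); done {
//			return err
//		}
//		// ... serve the request locally ...
//		return nil
//	}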
// streamingRpc creates a connection to the given server and conducts the
// initial handshake, returning the connection or an error. It is the caller's
// responsibility to close the connection if there is no returned error.
func (s *Server) streamingRpc(server *serverParts, method string) (net.Conn, error) {
	// Try to dial the server
	conn, err := net.DialTimeout("tcp", server.Addr.String(), 10*time.Second)
	if err != nil {
		return nil, err
	}

	// Cast to TCPConn
	if tcp, ok := conn.(*net.TCPConn); ok {
		tcp.SetKeepAlive(true)
		tcp.SetNoDelay(true)
	}

	if err := s.streamingRpcImpl(conn, server.Region, method); err != nil {
		return nil, err
	}

	return conn, nil
}

// streamingRpcImpl takes a pre-established connection to a server and conducts
// the handshake to establish a streaming RPC for the given method. If an error
// is returned, the underlying connection has been closed. Otherwise it is
// assumed that the connection has been hijacked by the RPC method.
func (s *Server) streamingRpcImpl(conn net.Conn, region, method string) error {
	// Check if TLS is enabled
	s.tlsWrapLock.RLock()
	tlsWrap := s.tlsWrap
	s.tlsWrapLock.RUnlock()

	if tlsWrap != nil {
		// Switch the connection into TLS mode
		if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
			conn.Close()
			return err
		}

		// Wrap the connection in a TLS client
		tlsConn, err := tlsWrap(region, conn)
		if err != nil {
			conn.Close()
			return err
		}
		conn = tlsConn
	}

	// Write the streaming byte to set the mode
	if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
		conn.Close()
		return err
	}

	// Send the header
	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
	header := structs.StreamingRpcHeader{
		Method: method,
	}
	if err := encoder.Encode(header); err != nil {
		conn.Close()
		return err
	}

	// Wait for the acknowledgement
	var ack structs.StreamingRpcAck
	if err := decoder.Decode(&ack); err != nil {
		conn.Close()
		return err
	}

	if ack.Error != "" {
		conn.Close()
		return errors.New(ack.Error)
	}

	return nil
}
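// Taken together with handleStreamingConn, the full client wire sequence for
// a streaming RPC is: an optional RpcTLS preamble byte followed by a TLS
// upgrade, the RpcStreaming byte, a msgpack-encoded StreamingRpcHeader naming
// the method, and then a msgpack-encoded StreamingRpcAck from the server
// before the connection is handed over to the method's handler.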
// raftApplyFuture is used to encode a message, run it through raft, and
// return the Raft future.
func (s *Server) raftApplyFuture(t structs.MessageType, msg interface{}) (raft.ApplyFuture, error) {
	buf, err := structs.Encode(t, msg)
	if err != nil {
		return nil, fmt.Errorf("Failed to encode request: %v", err)
	}

	// Warn if the command is very large
	if n := len(buf); n > raftWarnSize {
		s.logger.Printf("[WARN] nomad: Attempting to apply large raft entry (type %d) (%d bytes)", t, n)
	}

	future := s.raft.Apply(buf, enqueueLimit)
	return future, nil
}

// raftApplyFn is the function signature for applying a msg to Raft
type raftApplyFn func(t structs.MessageType, msg interface{}) (interface{}, uint64, error)

// raftApply is used to encode a message, run it through raft, and return
// the FSM response along with any errors
func (s *Server) raftApply(t structs.MessageType, msg interface{}) (interface{}, uint64, error) {
	future, err := s.raftApplyFuture(t, msg)
	if err != nil {
		return nil, 0, err
	}
	if err := future.Error(); err != nil {
		return nil, 0, err
	}
	return future.Response(), future.Index(), nil
}
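// An illustrative sketch, not part of the original file, of how a write RPC
// would commit a change through raftApply; the message type and request value
// are hypothetical:
//
//	_, index, err := s.raftApply(structs.ExampleRequestType, req)
//	if err != nil {
//		return err
//	}
//	reply.Index = index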
// setQueryMeta is used to populate the QueryMeta data for an RPC call
func (s *Server) setQueryMeta(m *structs.QueryMeta) {
	if s.IsLeader() {
		m.LastContact = 0
		m.KnownLeader = true
	} else {
		m.LastContact = time.Since(s.raft.LastContact())
		m.KnownLeader = (s.raft.Leader() != "")
	}
}

// queryFn is used to perform a query operation. If a re-query is needed, the
// passed-in watch set will be used to block for changes. The passed-in state
// store should be used (vs. calling fsm.State()) since the given state store
// will be correctly watched for changes if the state store is restored from
// a snapshot.
type queryFn func(memdb.WatchSet, *state.StateStore) error

// blockingOptions is used to parameterize blockingRPC
type blockingOptions struct {
	queryOpts *structs.QueryOptions
	queryMeta *structs.QueryMeta
	run       queryFn
}

// blockingRPC is used for queries that need to wait for a
// minimum index. This is used to block and wait for changes.
func (s *Server) blockingRPC(opts *blockingOptions) error {
	ctx := context.Background()
	var cancel context.CancelFunc
	var state *state.StateStore

	// Fast path non-blocking
	if opts.queryOpts.MinQueryIndex == 0 {
		goto RUN_QUERY
	}

	// Restrict the max query time, and ensure there is always one
	if opts.queryOpts.MaxQueryTime > maxQueryTime {
		opts.queryOpts.MaxQueryTime = maxQueryTime
	} else if opts.queryOpts.MaxQueryTime <= 0 {
		opts.queryOpts.MaxQueryTime = defaultQueryTime
	}

	// Apply a small amount of jitter to the request
	opts.queryOpts.MaxQueryTime += lib.RandomStagger(opts.queryOpts.MaxQueryTime / structs.JitterFraction)

	// Setup a query timeout
	ctx, cancel = context.WithTimeout(context.Background(), opts.queryOpts.MaxQueryTime)
	defer cancel()

RUN_QUERY:
	// Update the query metadata
	s.setQueryMeta(opts.queryMeta)

	// Increment the RPC query counter
	metrics.IncrCounter([]string{"nomad", "rpc", "query"}, 1)

	// We capture the state store and its abandon channel but pass a snapshot
	// to the blocking query function. We operate on the snapshot so that
	// separate calls to the state store are not all wrapped within the same
	// transaction.
	state = s.fsm.State()
	abandonCh := state.AbandonCh()
	snap, _ := state.Snapshot()
	stateSnap := &snap.StateStore

	// We can skip all watch tracking if this isn't a blocking query.
	var ws memdb.WatchSet
	if opts.queryOpts.MinQueryIndex > 0 {
		ws = memdb.NewWatchSet()

		// This channel will be closed if a snapshot is restored and the
		// whole state store is abandoned.
		ws.Add(abandonCh)
	}

	// Block up to the timeout if we didn't see anything fresh.
	err := opts.run(ws, stateSnap)

	// Check for minimum query time
	if err == nil && opts.queryOpts.MinQueryIndex > 0 && opts.queryMeta.Index <= opts.queryOpts.MinQueryIndex {
		if err := ws.WatchCtx(ctx); err == nil {
			goto RUN_QUERY
		}
	}
	return err
}
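// An illustrative sketch, not part of the original file, of a blocking query
// endpoint built on blockingRPC; the lookup logic is hypothetical:
//
//	opts := blockingOptions{
//		queryOpts: &args.QueryOptions,
//		queryMeta: &reply.QueryMeta,
//		run: func(ws memdb.WatchSet, state *state.StateStore) error {
//			// Look up data against the snapshot, adding the relevant
//			// watch channels to ws so the query wakes on changes, and
//			// set reply.Index from the table's last index.
//			return nil
//		},
//	}
//	return s.blockingRPC(&opts)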