github.com/bigcommerce/nomad@v0.9.3-bc/nomad/rpc.go

package nomad

import (
	"context"
	"crypto/tls"
	"crypto/x509"
	"errors"
	"fmt"
	"io"
	"math/rand"
	"net"
	"net/rpc"
	"strings"
	"time"

	golog "log"

	metrics "github.com/armon/go-metrics"
	log "github.com/hashicorp/go-hclog"
	memdb "github.com/hashicorp/go-memdb"

	"github.com/hashicorp/consul/lib"
	"github.com/hashicorp/nomad/helper/pool"
	"github.com/hashicorp/nomad/nomad/state"
	"github.com/hashicorp/nomad/nomad/structs"
	"github.com/hashicorp/raft"
	"github.com/hashicorp/yamux"
	"github.com/ugorji/go/codec"
)

const (
	// maxQueryTime is used to bound the limit of a blocking query
	maxQueryTime = 300 * time.Second

	// defaultQueryTime is the amount of time we block waiting for a change
	// if no time is specified. Previously we would wait the maxQueryTime.
	defaultQueryTime = 300 * time.Second

	// Warn if the Raft command is larger than this.
	// If it's over 1MB something is probably being abusive.
	raftWarnSize = 1024 * 1024

	// enqueueLimit caps how long we will wait to enqueue
	// a new Raft command. Something is probably wrong if this
	// value is ever reached. However, it prevents us from blocking
	// the requesting goroutine forever.
	enqueueLimit = 30 * time.Second
)

type rpcHandler struct {
	*Server
	logger   log.Logger
	gologger *golog.Logger
}

func newRpcHandler(s *Server) *rpcHandler {
	logger := s.logger.Named("rpc")
	return &rpcHandler{
		Server:   s,
		logger:   logger,
		gologger: logger.StandardLogger(&log.StandardLoggerOptions{InferLevels: true}),
	}
}

// RPCContext provides metadata about the RPC connection.
type RPCContext struct {
	// Conn exposes the raw connection.
	Conn net.Conn

	// Session exposes the multiplexed connection session.
	Session *yamux.Session

	// TLS marks whether the RPC is over a TLS based connection
	TLS bool

	// VerifiedChains holds the verified certificate chains presented by the
	// incoming connection.
	VerifiedChains [][]*x509.Certificate

	// NodeID marks the NodeID that initiated the connection.
	NodeID string
}

// listen is used to listen for incoming RPC connections
func (r *rpcHandler) listen(ctx context.Context) {
	defer close(r.listenerCh)

	var acceptLoopDelay time.Duration
	for {
		select {
		case <-ctx.Done():
			r.logger.Info("closing server RPC connection")
			return
		default:
		}

		// Accept a connection
		conn, err := r.rpcListener.Accept()
		if err != nil {
			if r.shutdown {
				return
			}
			r.handleAcceptErr(ctx, err, &acceptLoopDelay)
			continue
		}
		// No error, reset loop delay
		acceptLoopDelay = 0

		go r.handleConn(ctx, conn, &RPCContext{Conn: conn})
		metrics.IncrCounter([]string{"nomad", "rpc", "accept_conn"}, 1)
	}
}

// handleAcceptErr sleeps to avoid spamming the log,
// with a maximum delay according to whether or not the error is temporary
func (r *rpcHandler) handleAcceptErr(ctx context.Context, err error, loopDelay *time.Duration) {
	const baseDelay = 5 * time.Millisecond
	const maxDelayPerm = 5 * time.Second
	const maxDelayTemp = 1 * time.Second

	if *loopDelay == 0 {
		*loopDelay = baseDelay
	} else {
		*loopDelay *= 2
	}

	temporaryError := false
	if ne, ok := err.(net.Error); ok && ne.Temporary() {
		temporaryError = true
	}

	if temporaryError && *loopDelay > maxDelayTemp {
		*loopDelay = maxDelayTemp
	} else if *loopDelay > maxDelayPerm {
		*loopDelay = maxDelayPerm
	}

	r.logger.Error("failed to accept RPC conn", "error", err, "delay", *loopDelay)

	select {
	case <-ctx.Done():
	case <-time.After(*loopDelay):
	}
}
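// exampleStartRPC is an illustrative sketch and not part of the original
// file. It shows how a Server might plausibly construct the handler and run
// the accept loop; the context wiring is an assumption for illustration.
func exampleStartRPC(ctx context.Context, s *Server) *rpcHandler {
	h := newRpcHandler(s)

	// listen blocks until the context is cancelled, so it runs in its own
	// goroutine; accepted connections are handed to handleConn.
	go h.listen(ctx)
	return h
}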
// handleConn is used to determine if this is a Raft or
// Nomad type RPC connection and invoke the correct handler
func (r *rpcHandler) handleConn(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
	// Read a single byte
	buf := make([]byte, 1)
	if _, err := conn.Read(buf); err != nil {
		if err != io.EOF {
			r.logger.Error("failed to read first RPC byte", "error", err)
		}
		conn.Close()
		return
	}

	// Enforce TLS if EnableRPC is set
	if r.config.TLSConfig.EnableRPC && !rpcCtx.TLS && pool.RPCType(buf[0]) != pool.RpcTLS {
		if !r.config.TLSConfig.RPCUpgradeMode {
			r.logger.Warn("non-TLS connection attempted with RequireTLS set", "remote_addr", conn.RemoteAddr())
			conn.Close()
			return
		}
	}

	// Switch on the byte
	switch pool.RPCType(buf[0]) {
	case pool.RpcNomad:
		// Create an RPC Server and handle the request
		server := rpc.NewServer()
		r.setupRpcServer(server, rpcCtx)
		r.handleNomadConn(ctx, conn, server)

		// Remove any potential mapping between a NodeID to this connection and
		// close the underlying connection.
		r.removeNodeConn(rpcCtx)

	case pool.RpcRaft:
		metrics.IncrCounter([]string{"nomad", "rpc", "raft_handoff"}, 1)
		r.raftLayer.Handoff(ctx, conn)

	case pool.RpcMultiplex:
		r.handleMultiplex(ctx, conn, rpcCtx)

	case pool.RpcTLS:
		if r.rpcTLS == nil {
			r.logger.Warn("TLS connection attempted, server not configured for TLS")
			conn.Close()
			return
		}
		conn = tls.Server(conn, r.rpcTLS)

		// Force a handshake so we can get information about the TLS connection
		// state.
		tlsConn, ok := conn.(*tls.Conn)
		if !ok {
			r.logger.Error("expected TLS connection", "got", log.Fmt("%T", conn))
			conn.Close()
			return
		}

		if err := tlsConn.Handshake(); err != nil {
			r.logger.Warn("failed TLS handshake", "remote_addr", tlsConn.RemoteAddr(), "error", err)
			conn.Close()
			return
		}

		// Update the connection context with the fact that the connection is
		// using TLS
		rpcCtx.TLS = true

		// Store the verified chains so they can be inspected later.
		state := tlsConn.ConnectionState()
		rpcCtx.VerifiedChains = state.VerifiedChains

		r.handleConn(ctx, conn, rpcCtx)

	case pool.RpcStreaming:
		r.handleStreamingConn(conn)

	case pool.RpcMultiplexV2:
		r.handleMultiplexV2(ctx, conn, rpcCtx)

	default:
		r.logger.Error("unrecognized RPC byte", "byte", buf[0])
		conn.Close()
		return
	}
}

// handleMultiplex is used to multiplex a single incoming connection
// using the Yamux multiplexer
func (r *rpcHandler) handleMultiplex(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
	defer func() {
		// Remove any potential mapping between a NodeID to this connection and
		// close the underlying connection.
		r.removeNodeConn(rpcCtx)
		conn.Close()
	}()

	conf := yamux.DefaultConfig()
	conf.LogOutput = nil
	conf.Logger = r.gologger
	server, err := yamux.Server(conn, conf)
	if err != nil {
		r.logger.Error("multiplex failed to create yamux server", "error", err)
		return
	}

	// Update the context to store the yamux session
	rpcCtx.Session = server

	// Create the RPC server for this connection
	rpcServer := rpc.NewServer()
	r.setupRpcServer(rpcServer, rpcCtx)

	for {
		// stop handling connections if context was cancelled
		if ctx.Err() != nil {
			return
		}

		sub, err := server.Accept()
		if err != nil {
			if err != io.EOF {
				r.logger.Error("multiplex conn accept failed", "error", err)
			}
			return
		}
		go r.handleNomadConn(ctx, sub, rpcServer)
	}
}

// handleNomadConn is used to service a single Nomad RPC connection
func (r *rpcHandler) handleNomadConn(ctx context.Context, conn net.Conn, server *rpc.Server) {
	defer conn.Close()
	rpcCodec := pool.NewServerCodec(conn)
	for {
		select {
		case <-ctx.Done():
			r.logger.Info("closing server RPC connection")
			return
		case <-r.shutdownCh:
			return
		default:
		}

		if err := server.ServeRequest(rpcCodec); err != nil {
			if err != io.EOF && !strings.Contains(err.Error(), "closed") {
				r.logger.Error("RPC error", "error", err, "connection", conn)
				metrics.IncrCounter([]string{"nomad", "rpc", "request_error"}, 1)
			}
			return
		}
		metrics.IncrCounter([]string{"nomad", "rpc", "request"}, 1)
	}
}
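// exampleDialNomad is an illustrative sketch and not part of the original
// file. It shows the client side of the first-byte protocol that handleConn
// switches on: a single pool.RPCType byte selects the connection mode before
// any RPC data is sent. The address and timeout are assumptions for
// illustration.
func exampleDialNomad(addr string) (net.Conn, error) {
	conn, err := net.DialTimeout("tcp", addr, 10*time.Second)
	if err != nil {
		return nil, err
	}

	// Select plain Nomad RPC mode; pool.RpcTLS or pool.RpcMultiplexV2 would
	// select the other branches of handleConn's switch.
	if _, err := conn.Write([]byte{byte(pool.RpcNomad)}); err != nil {
		conn.Close()
		return nil, err
	}
	return conn, nil
}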
// handleStreamingConn is used to handle a single Streaming Nomad RPC connection.
func (r *rpcHandler) handleStreamingConn(conn net.Conn) {
	defer conn.Close()

	// Decode the header
	var header structs.StreamingRpcHeader
	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
	if err := decoder.Decode(&header); err != nil {
		if err != io.EOF && !strings.Contains(err.Error(), "closed") {
			r.logger.Error("streaming RPC error", "error", err, "connection", conn)
			metrics.IncrCounter([]string{"nomad", "streaming_rpc", "request_error"}, 1)
		}

		return
	}

	ack := structs.StreamingRpcAck{}
	handler, err := r.streamingRpcs.GetHandler(header.Method)
	if err != nil {
		r.logger.Error("streaming RPC error", "error", err, "connection", conn)
		metrics.IncrCounter([]string{"nomad", "streaming_rpc", "request_error"}, 1)
		ack.Error = err.Error()
	}

	// Send the acknowledgement
	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
	if err := encoder.Encode(ack); err != nil {
		conn.Close()
		return
	}

	if ack.Error != "" {
		return
	}

	// Invoke the handler
	metrics.IncrCounter([]string{"nomad", "streaming_rpc", "request"}, 1)
	handler(conn)
}

// handleMultiplexV2 is used to multiplex a single incoming connection
// using the Yamux multiplexer. Version 2 handling allows a single connection to
// switch streams between regular RPCs and Streaming RPCs.
func (r *rpcHandler) handleMultiplexV2(ctx context.Context, conn net.Conn, rpcCtx *RPCContext) {
	defer func() {
		// Remove any potential mapping between a NodeID to this connection and
		// close the underlying connection.
		r.removeNodeConn(rpcCtx)
		conn.Close()
	}()

	conf := yamux.DefaultConfig()
	conf.LogOutput = nil
	conf.Logger = r.gologger
	server, err := yamux.Server(conn, conf)
	if err != nil {
		r.logger.Error("multiplex_v2 failed to create yamux server", "error", err)
		return
	}

	// Update the context to store the yamux session
	rpcCtx.Session = server

	// Create the RPC server for this connection
	rpcServer := rpc.NewServer()
	r.setupRpcServer(rpcServer, rpcCtx)

	for {
		// stop handling connections if context was cancelled
		if ctx.Err() != nil {
			return
		}

		// Accept a new stream
		sub, err := server.Accept()
		if err != nil {
			if err != io.EOF {
				r.logger.Error("multiplex_v2 conn accept failed", "error", err)
			}
			return
		}

		// Read a single byte
		buf := make([]byte, 1)
		if _, err := sub.Read(buf); err != nil {
			if err != io.EOF {
				r.logger.Error("multiplex_v2 failed to read first byte", "error", err)
			}
			return
		}

		// Determine which handler to use
		switch pool.RPCType(buf[0]) {
		case pool.RpcNomad:
			go r.handleNomadConn(ctx, sub, rpcServer)
		case pool.RpcStreaming:
			go r.handleStreamingConn(sub)

		default:
			r.logger.Error("multiplex_v2 unrecognized first RPC byte", "byte", buf[0])
			return
		}
	}
}
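// exampleMultiplexV2Client is an illustrative sketch and not part of the
// original file. It shows the client side of the connection layout that
// handleMultiplexV2 serves: the connection is upgraded to yamux with the
// RpcMultiplexV2 byte, and each opened stream is then prefixed with its own
// RPC type byte so regular and streaming RPCs can share one connection.
func exampleMultiplexV2Client(conn net.Conn) (net.Conn, error) {
	// Upgrade the raw connection to multiplex v2 mode.
	if _, err := conn.Write([]byte{byte(pool.RpcMultiplexV2)}); err != nil {
		conn.Close()
		return nil, err
	}

	session, err := yamux.Client(conn, yamux.DefaultConfig())
	if err != nil {
		conn.Close()
		return nil, err
	}

	// Open a stream and declare it as a regular Nomad RPC stream;
	// pool.RpcStreaming would route it to handleStreamingConn instead.
	stream, err := session.Open()
	if err != nil {
		session.Close()
		return nil, err
	}
	if _, err := stream.Write([]byte{byte(pool.RpcNomad)}); err != nil {
		session.Close()
		return nil, err
	}
	return stream, nil
}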
// forward is used to forward to a remote region or to forward to the local leader.
// Returns a bool indicating whether forwarding was performed, as well as any error.
func (r *rpcHandler) forward(method string, info structs.RPCInfo, args interface{}, reply interface{}) (bool, error) {
	var firstCheck time.Time

	region := info.RequestRegion()
	if region == "" {
		return true, fmt.Errorf("missing target RPC")
	}

	// Handle region forwarding
	if region != r.config.Region {
		// Mark that we are forwarding the RPC
		info.SetForwarded()
		err := r.forwardRegion(region, method, args, reply)
		return true, err
	}

	// Check if we can allow a stale read
	if info.IsRead() && info.AllowStaleRead() {
		return false, nil
	}

CHECK_LEADER:
	// Find the leader
	isLeader, remoteServer := r.getLeader()

	// Handle the case we are the leader
	if isLeader {
		return false, nil
	}

	// Handle the case of a known leader
	if remoteServer != nil {
		// Mark that we are forwarding the RPC
		info.SetForwarded()
		err := r.forwardLeader(remoteServer, method, args, reply)
		return true, err
	}

	// Gate the request until there is a leader
	if firstCheck.IsZero() {
		firstCheck = time.Now()
	}
	if time.Now().Sub(firstCheck) < r.config.RPCHoldTimeout {
		jitter := lib.RandomStagger(r.config.RPCHoldTimeout / structs.JitterFraction)
		select {
		case <-time.After(jitter):
			goto CHECK_LEADER
		case <-r.shutdownCh:
		}
	}

	// No leader found and hold time exceeded
	return true, structs.ErrNoLeader
}

// getLeader returns whether the current node is the leader. If it is not,
// it returns the leader server, which is potentially nil if the cluster
// has not yet elected a leader.
func (s *Server) getLeader() (bool, *serverParts) {
	// Check if we are the leader
	if s.IsLeader() {
		return true, nil
	}

	// Get the leader
	leader := s.raft.Leader()
	if leader == "" {
		return false, nil
	}

	// Lookup the server
	s.peerLock.RLock()
	server := s.localPeers[leader]
	s.peerLock.RUnlock()

	// Server could be nil
	return false, server
}

// forwardLeader is used to forward an RPC call to the leader, or fail if no leader
func (r *rpcHandler) forwardLeader(server *serverParts, method string, args interface{}, reply interface{}) error {
	// Handle a missing server
	if server == nil {
		return structs.ErrNoLeader
	}
	return r.connPool.RPC(r.config.Region, server.Addr, server.MajorVersion, method, args, reply)
}

// forwardServer is used to forward an RPC call to a particular server
func (r *rpcHandler) forwardServer(server *serverParts, method string, args interface{}, reply interface{}) error {
	// Handle a missing server
	if server == nil {
		return errors.New("must be given a valid server address")
	}
	return r.connPool.RPC(r.config.Region, server.Addr, server.MajorVersion, method, args, reply)
}

// forwardRegion is used to forward an RPC call to a remote region, or fail if no servers
func (r *rpcHandler) forwardRegion(region, method string, args interface{}, reply interface{}) error {
	// Bail if we can't find any servers
	r.peerLock.RLock()
	servers := r.peers[region]
	if len(servers) == 0 {
		r.peerLock.RUnlock()
		r.logger.Warn("no path found to region", "region", region)
		return structs.ErrNoRegionPath
	}

	// Select a random addr
	offset := rand.Intn(len(servers))
	server := servers[offset]
	r.peerLock.RUnlock()

	// Forward to remote Nomad
	metrics.IncrCounter([]string{"nomad", "rpc", "cross-region", region}, 1)
	return r.connPool.RPC(region, server.Addr, server.MajorVersion, method, args, reply)
}
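// exampleForwardingEndpoint is an illustrative sketch and not part of the
// original file. It shows the call pattern endpoints typically use with
// forward: return early when the request was forwarded to another region or
// to the leader, and only fall through to local handling otherwise. The
// method name "Example.Method" and the generic request/response types are
// assumptions for illustration.
func (r *rpcHandler) exampleForwardingEndpoint(args *structs.GenericRequest, reply *structs.GenericResponse) error {
	if done, err := r.forward("Example.Method", args, args, reply); done {
		return err
	}

	// Reaching this point means this server is the leader (or a stale read
	// was allowed), so the request is handled locally.
	return nil
}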
// streamingRpc creates a connection to the given server and conducts the
// initial handshake, returning the connection or an error. It is the caller's
// responsibility to close the connection if there is no returned error.
func (r *rpcHandler) streamingRpc(server *serverParts, method string) (net.Conn, error) {
	// Try to dial the server
	conn, err := net.DialTimeout("tcp", server.Addr.String(), 10*time.Second)
	if err != nil {
		return nil, err
	}

	// Cast to TCPConn
	if tcp, ok := conn.(*net.TCPConn); ok {
		tcp.SetKeepAlive(true)
		tcp.SetNoDelay(true)
	}

	if err := r.streamingRpcImpl(conn, server.Region, method); err != nil {
		return nil, err
	}

	return conn, nil
}

// streamingRpcImpl takes a pre-established connection to a server and conducts
// the handshake to establish a streaming RPC for the given method. If an error
// is returned, the underlying connection has been closed. Otherwise it is
// assumed that the connection has been hijacked by the RPC method.
func (r *rpcHandler) streamingRpcImpl(conn net.Conn, region, method string) error {
	// Check if TLS is enabled
	r.tlsWrapLock.RLock()
	tlsWrap := r.tlsWrap
	r.tlsWrapLock.RUnlock()

	if tlsWrap != nil {
		// Switch the connection into TLS mode
		if _, err := conn.Write([]byte{byte(pool.RpcTLS)}); err != nil {
			conn.Close()
			return err
		}

		// Wrap the connection in a TLS client
		tlsConn, err := tlsWrap(region, conn)
		if err != nil {
			conn.Close()
			return err
		}
		conn = tlsConn
	}

	// Write the multiplex byte to set the mode
	if _, err := conn.Write([]byte{byte(pool.RpcStreaming)}); err != nil {
		conn.Close()
		return err
	}

	// Send the header
	encoder := codec.NewEncoder(conn, structs.MsgpackHandle)
	decoder := codec.NewDecoder(conn, structs.MsgpackHandle)
	header := structs.StreamingRpcHeader{
		Method: method,
	}
	if err := encoder.Encode(header); err != nil {
		conn.Close()
		return err
	}

	// Wait for the acknowledgement
	var ack structs.StreamingRpcAck
	if err := decoder.Decode(&ack); err != nil {
		conn.Close()
		return err
	}

	if ack.Error != "" {
		conn.Close()
		return errors.New(ack.Error)
	}

	return nil
}
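// exampleOpenStream is an illustrative sketch and not part of the original
// file. It shows how a caller might use streamingRpc: once the handshake
// succeeds, the returned connection is dedicated to the streaming method and
// the caller owns closing it. The method name is an assumption for
// illustration.
func (r *rpcHandler) exampleOpenStream(server *serverParts) error {
	conn, err := r.streamingRpc(server, "Example.Stream")
	if err != nil {
		return err
	}
	// The connection was hijacked by the remote method; close it when the
	// stream is done.
	defer conn.Close()

	// From here the connection speaks the method's own framing, typically
	// msgpack messages encoded with structs.MsgpackHandle.
	_ = codec.NewDecoder(conn, structs.MsgpackHandle)
	return nil
}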
// raftApplyFuture is used to encode a message, run it through raft, and return the Raft future.
func (s *Server) raftApplyFuture(t structs.MessageType, msg interface{}) (raft.ApplyFuture, error) {
	buf, err := structs.Encode(t, msg)
	if err != nil {
		return nil, fmt.Errorf("Failed to encode request: %v", err)
	}

	// Warn if the command is very large
	if n := len(buf); n > raftWarnSize {
		s.logger.Warn("attempting to apply large raft entry", "raft_type", t, "bytes", n)
	}

	future := s.raft.Apply(buf, enqueueLimit)
	return future, nil
}

// raftApplyFn is the function signature for applying a msg to Raft
type raftApplyFn func(t structs.MessageType, msg interface{}) (interface{}, uint64, error)

// raftApply is used to encode a message, run it through raft, and return
// the FSM response along with any errors
func (s *Server) raftApply(t structs.MessageType, msg interface{}) (interface{}, uint64, error) {
	future, err := s.raftApplyFuture(t, msg)
	if err != nil {
		return nil, 0, err
	}
	if err := future.Error(); err != nil {
		return nil, 0, err
	}
	return future.Response(), future.Index(), nil
}
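// exampleApply is an illustrative sketch and not part of the original file.
// It shows a typical raftApply call site: submit a typed message, then check
// both the Raft-level error and the FSM response, which some handlers use to
// return an error value. The concrete message type is an assumption for
// illustration.
func (s *Server) exampleApply(req *structs.NodeDeregisterRequest) (uint64, error) {
	resp, index, err := s.raftApply(structs.NodeDeregisterRequestType, req)
	if err != nil {
		return 0, err
	}
	if respErr, ok := resp.(error); ok && respErr != nil {
		return 0, respErr
	}
	return index, nil
}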
// setQueryMeta is used to populate the QueryMeta data for an RPC call
func (r *rpcHandler) setQueryMeta(m *structs.QueryMeta) {
	if r.IsLeader() {
		m.LastContact = 0
		m.KnownLeader = true
	} else {
		m.LastContact = time.Now().Sub(r.raft.LastContact())
		m.KnownLeader = (r.raft.Leader() != "")
	}
}

// queryFn is used to perform a query operation. If a re-query is needed, the
// passed-in watch set will be used to block for changes. The passed-in state
// store should be used (vs. calling fsm.State()) since the given state store
// will be correctly watched for changes if the state store is restored from
// a snapshot.
type queryFn func(memdb.WatchSet, *state.StateStore) error

// blockingOptions is used to parameterize blockingRPC
type blockingOptions struct {
	queryOpts *structs.QueryOptions
	queryMeta *structs.QueryMeta
	run       queryFn
}

// blockingRPC is used for queries that need to wait for a
// minimum index. This is used to block and wait for changes.
func (r *rpcHandler) blockingRPC(opts *blockingOptions) error {
	ctx := context.Background()
	var cancel context.CancelFunc
	var state *state.StateStore

	// Fast path non-blocking
	if opts.queryOpts.MinQueryIndex == 0 {
		goto RUN_QUERY
	}

	// Restrict the max query time, and ensure there is always one
	if opts.queryOpts.MaxQueryTime > maxQueryTime {
		opts.queryOpts.MaxQueryTime = maxQueryTime
	} else if opts.queryOpts.MaxQueryTime <= 0 {
		opts.queryOpts.MaxQueryTime = defaultQueryTime
	}

	// Apply a small amount of jitter to the request
	opts.queryOpts.MaxQueryTime += lib.RandomStagger(opts.queryOpts.MaxQueryTime / structs.JitterFraction)

	// Setup a query timeout
	ctx, cancel = context.WithTimeout(context.Background(), opts.queryOpts.MaxQueryTime)
	defer cancel()

RUN_QUERY:
	// Update the query meta data
	r.setQueryMeta(opts.queryMeta)

	// Increment the rpc query counter
	metrics.IncrCounter([]string{"nomad", "rpc", "query"}, 1)

	// We capture the state store and its abandon channel but pass a snapshot to
	// the blocking query function. We operate on the snapshot to allow separate
	// calls to the state store not all wrapped within the same transaction.
	state = r.fsm.State()
	abandonCh := state.AbandonCh()
	snap, _ := state.Snapshot()
	stateSnap := &snap.StateStore

	// We can skip all watch tracking if this isn't a blocking query.
	var ws memdb.WatchSet
	if opts.queryOpts.MinQueryIndex > 0 {
		ws = memdb.NewWatchSet()

		// This channel will be closed if a snapshot is restored and the
		// whole state store is abandoned.
		ws.Add(abandonCh)
	}

	// Block up to the timeout if we didn't see anything fresh.
	err := opts.run(ws, stateSnap)

	// Check for minimum query time
	if err == nil && opts.queryOpts.MinQueryIndex > 0 && opts.queryMeta.Index <= opts.queryOpts.MinQueryIndex {
		if err := ws.WatchCtx(ctx); err == nil {
			goto RUN_QUERY
		}
	}
	return err
}
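// exampleBlockingQuery is an illustrative sketch and not part of the original
// file. It shows how an endpoint typically drives blockingRPC: wire the
// request's QueryOptions and the reply's QueryMeta into blockingOptions and
// perform the watched lookup inside run, setting reply.Index so the blocking
// loop can make progress. The node lookup and "nodes" table index are
// assumptions for illustration.
func (r *rpcHandler) exampleBlockingQuery(args *structs.NodeSpecificRequest, reply *structs.SingleNodeResponse) error {
	opts := blockingOptions{
		queryOpts: &args.QueryOptions,
		queryMeta: &reply.QueryMeta,
		run: func(ws memdb.WatchSet, state *state.StateStore) error {
			// Look up the node against the snapshot, registering watches.
			node, err := state.NodeByID(ws, args.NodeID)
			if err != nil {
				return err
			}

			reply.Node = node
			if node != nil {
				reply.Index = node.ModifyIndex
			} else {
				// Use the last index that affected the nodes table so a
				// blocking query can still unblock on later changes.
				index, err := state.Index("nodes")
				if err != nil {
					return err
				}
				reply.Index = index
			}
			return nil
		},
	}
	return r.blockingRPC(&opts)
}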