github.com/clly/consul@v1.4.5/agent/consul/rpc.go (about)

     1  package consul
     2  
     3  import (
     4  	"crypto/tls"
     5  	"fmt"
     6  	"io"
     7  	"net"
     8  	"strings"
     9  	"time"
    10  
    11  	"github.com/armon/go-metrics"
    12  	"github.com/hashicorp/consul/agent/consul/state"
    13  	"github.com/hashicorp/consul/agent/metadata"
    14  	"github.com/hashicorp/consul/agent/pool"
    15  	"github.com/hashicorp/consul/agent/structs"
    16  	"github.com/hashicorp/consul/lib"
    17  	memdb "github.com/hashicorp/go-memdb"
    18  	"github.com/hashicorp/memberlist"
    19  	msgpackrpc "github.com/hashicorp/net-rpc-msgpackrpc"
    20  	"github.com/hashicorp/yamux"
    21  )
    22  
const (
	// maxQueryTime is the upper bound applied to the duration of a
	// blocking query.
	maxQueryTime = 600 * time.Second

	// defaultQueryTime is the amount of time we block waiting for a change
	// if no time is specified. Previously we would wait the maxQueryTime.
	defaultQueryTime = 300 * time.Second

	// jitterFraction limits the amount of jitter we apply to a
	// user-specified MaxQueryTime. We divide the specified time by this
	// fraction, so 16 == 6.25% limit of jitter. The same fraction is
	// applied to the RPCHoldTimeout.
	jitterFraction = 16

	// raftWarnSize is the size in bytes above which we warn about a Raft
	// command. If it's over 1MB something is probably being abusive.
	raftWarnSize = 1024 * 1024

	// enqueueLimit caps how long we will wait to enqueue
	// a new Raft command. Something is probably wrong if this
	// value is ever reached. However, it prevents us from blocking
	// the requesting goroutine forever.
	enqueueLimit = 30 * time.Second
)
    47  
    48  // listen is used to listen for incoming RPC connections
    49  func (s *Server) listen(listener net.Listener) {
    50  	for {
    51  		// Accept a connection
    52  		conn, err := listener.Accept()
    53  		if err != nil {
    54  			if s.shutdown {
    55  				return
    56  			}
    57  			s.logger.Printf("[ERR] consul.rpc: failed to accept RPC conn: %v", err)
    58  			continue
    59  		}
    60  
    61  		go s.handleConn(conn, false)
    62  		metrics.IncrCounter([]string{"rpc", "accept_conn"}, 1)
    63  	}
    64  }
    65  
    66  // logConn is a wrapper around memberlist's LogConn so that we format references
    67  // to "from" addresses in a consistent way. This is just a shorter name.
    68  func logConn(conn net.Conn) string {
    69  	return memberlist.LogConn(conn)
    70  }
    71  
    72  // handleConn is used to determine if this is a Raft or
    73  // Consul type RPC connection and invoke the correct handler
    74  func (s *Server) handleConn(conn net.Conn, isTLS bool) {
    75  	// Read a single byte
    76  	buf := make([]byte, 1)
    77  	if _, err := conn.Read(buf); err != nil {
    78  		if err != io.EOF {
    79  			s.logger.Printf("[ERR] consul.rpc: failed to read byte: %v %s", err, logConn(conn))
    80  		}
    81  		conn.Close()
    82  		return
    83  	}
    84  	typ := pool.RPCType(buf[0])
    85  
    86  	// Enforce TLS if VerifyIncoming is set
    87  	if s.config.VerifyIncoming && !isTLS && typ != pool.RPCTLS {
    88  		s.logger.Printf("[WARN] consul.rpc: Non-TLS connection attempted with VerifyIncoming set %s", logConn(conn))
    89  		conn.Close()
    90  		return
    91  	}
    92  
    93  	// Switch on the byte
    94  	switch typ {
    95  	case pool.RPCConsul:
    96  		s.handleConsulConn(conn)
    97  
    98  	case pool.RPCRaft:
    99  		metrics.IncrCounter([]string{"rpc", "raft_handoff"}, 1)
   100  		s.raftLayer.Handoff(conn)
   101  
   102  	case pool.RPCTLS:
   103  		if s.rpcTLS == nil {
   104  			s.logger.Printf("[WARN] consul.rpc: TLS connection attempted, server not configured for TLS %s", logConn(conn))
   105  			conn.Close()
   106  			return
   107  		}
   108  		conn = tls.Server(conn, s.rpcTLS)
   109  		s.handleConn(conn, true)
   110  
   111  	case pool.RPCMultiplexV2:
   112  		s.handleMultiplexV2(conn)
   113  
   114  	case pool.RPCSnapshot:
   115  		s.handleSnapshotConn(conn)
   116  
   117  	default:
   118  		if !s.handleEnterpriseRPCConn(typ, conn, isTLS) {
   119  			s.logger.Printf("[ERR] consul.rpc: unrecognized RPC byte: %v %s", typ, logConn(conn))
   120  			conn.Close()
   121  		}
   122  	}
   123  }
   124  
   125  // handleMultiplexV2 is used to multiplex a single incoming connection
   126  // using the Yamux multiplexer
   127  func (s *Server) handleMultiplexV2(conn net.Conn) {
   128  	defer conn.Close()
   129  	conf := yamux.DefaultConfig()
   130  	conf.LogOutput = s.config.LogOutput
   131  	server, _ := yamux.Server(conn, conf)
   132  	for {
   133  		sub, err := server.Accept()
   134  		if err != nil {
   135  			if err != io.EOF {
   136  				s.logger.Printf("[ERR] consul.rpc: multiplex conn accept failed: %v %s", err, logConn(conn))
   137  			}
   138  			return
   139  		}
   140  		go s.handleConsulConn(sub)
   141  	}
   142  }
   143  
   144  // handleConsulConn is used to service a single Consul RPC connection
   145  func (s *Server) handleConsulConn(conn net.Conn) {
   146  	defer conn.Close()
   147  	rpcCodec := msgpackrpc.NewServerCodec(conn)
   148  	for {
   149  		select {
   150  		case <-s.shutdownCh:
   151  			return
   152  		default:
   153  		}
   154  
   155  		if err := s.rpcServer.ServeRequest(rpcCodec); err != nil {
   156  			if err != io.EOF && !strings.Contains(err.Error(), "closed") {
   157  				s.logger.Printf("[ERR] consul.rpc: RPC error: %v %s", err, logConn(conn))
   158  				metrics.IncrCounter([]string{"rpc", "request_error"}, 1)
   159  			}
   160  			return
   161  		}
   162  		metrics.IncrCounter([]string{"rpc", "request"}, 1)
   163  	}
   164  }
   165  
   166  // handleSnapshotConn is used to dispatch snapshot saves and restores, which
   167  // stream so don't use the normal RPC mechanism.
   168  func (s *Server) handleSnapshotConn(conn net.Conn) {
   169  	go func() {
   170  		defer conn.Close()
   171  		if err := s.handleSnapshotRequest(conn); err != nil {
   172  			s.logger.Printf("[ERR] consul.rpc: Snapshot RPC error: %v %s", err, logConn(conn))
   173  		}
   174  	}()
   175  }
   176  
   177  // canRetry returns true if the given situation is safe for a retry.
   178  func canRetry(args interface{}, err error) bool {
   179  	// No leader errors are always safe to retry since no state could have
   180  	// been changed.
   181  	if structs.IsErrNoLeader(err) {
   182  		return true
   183  	}
   184  
   185  	// Reads are safe to retry for stream errors, such as if a server was
   186  	// being shut down.
   187  	info, ok := args.(structs.RPCInfo)
   188  	if ok && info.IsRead() && lib.IsErrEOF(err) {
   189  		return true
   190  	}
   191  
   192  	return false
   193  }
   194  
// forward is used to forward to a remote DC or to forward to the local leader.
//
// It returns a bool indicating whether forwarding was performed (or
// attempted), as well as any error. A (false, nil) result means the request
// was not forwarded and the caller should service it locally: either a stale
// read is permitted, or this server is the leader.
//
// When the request targets the local datacenter and no leader RPC succeeds,
// the call is retried until s.config.RPCHoldTimeout has elapsed since the
// first retry, or until the server starts leaving or shutting down.
func (s *Server) forward(method string, info structs.RPCInfo, args interface{}, reply interface{}) (bool, error) {
	// firstCheck records when the first retry started; it stays zero until
	// the RETRY label is reached for the first time.
	var firstCheck time.Time

	// Handle DC forwarding
	dc := info.RequestDatacenter()
	if dc != s.config.Datacenter {
		err := s.forwardDC(method, dc, args, reply)
		return true, err
	}

	// Check if we can allow a stale read, ensure our local DB is initialized
	// (a zero LastContact is treated as "not yet initialized").
	if info.IsRead() && info.AllowStaleRead() && !s.raft.LastContact().IsZero() {
		return false, nil
	}

CHECK_LEADER:
	// Fail fast if we are in the process of leaving
	select {
	case <-s.leaveCh:
		return true, structs.ErrNoLeader
	default:
	}

	// Find the leader
	isLeader, leader := s.getLeader()

	// Handle the case we are the leader
	if isLeader {
		return false, nil
	}

	// Handle the case of a known leader. rpcErr defaults to ErrNoLeader so
	// that is what gets returned if the leader is unknown for the whole
	// hold period.
	rpcErr := structs.ErrNoLeader
	if leader != nil {
		rpcErr = s.connPool.RPC(s.config.Datacenter, leader.Addr,
			leader.Version, method, leader.UseTLS, args, reply)
		// Note that canRetry is given info (not args) to classify the
		// request as a read.
		if rpcErr != nil && canRetry(info, rpcErr) {
			goto RETRY
		}
		return true, rpcErr
	}

RETRY:
	// Gate the request until there is a leader
	if firstCheck.IsZero() {
		firstCheck = time.Now()
	}
	if time.Since(firstCheck) < s.config.RPCHoldTimeout {
		// Wait for a staggered interval derived from
		// RPCHoldTimeout/jitterFraction before re-checking the leader.
		jitter := lib.RandomStagger(s.config.RPCHoldTimeout / jitterFraction)
		select {
		case <-time.After(jitter):
			goto CHECK_LEADER
		case <-s.leaveCh:
			// Leaving: fall through and return the last error.
		case <-s.shutdownCh:
			// Shutting down: fall through and return the last error.
		}
	}

	// No leader found and hold time exceeded
	return true, rpcErr
}
   257  
   258  // getLeader returns if the current node is the leader, and if not then it
   259  // returns the leader which is potentially nil if the cluster has not yet
   260  // elected a leader.
   261  func (s *Server) getLeader() (bool, *metadata.Server) {
   262  	// Check if we are the leader
   263  	if s.IsLeader() {
   264  		return true, nil
   265  	}
   266  
   267  	// Get the leader
   268  	leader := s.raft.Leader()
   269  	if leader == "" {
   270  		return false, nil
   271  	}
   272  
   273  	// Lookup the server
   274  	server := s.serverLookup.Server(leader)
   275  
   276  	// Server could be nil
   277  	return false, server
   278  }
   279  
   280  // forwardDC is used to forward an RPC call to a remote DC, or fail if no servers
   281  func (s *Server) forwardDC(method, dc string, args interface{}, reply interface{}) error {
   282  	manager, server, ok := s.router.FindRoute(dc)
   283  	if !ok {
   284  		s.logger.Printf("[WARN] consul.rpc: RPC request for DC %q, no path found", dc)
   285  		return structs.ErrNoDCPath
   286  	}
   287  
   288  	metrics.IncrCounterWithLabels([]string{"rpc", "cross-dc"}, 1,
   289  		[]metrics.Label{{Name: "datacenter", Value: dc}})
   290  	if err := s.connPool.RPC(dc, server.Addr, server.Version, method, server.UseTLS, args, reply); err != nil {
   291  		manager.NotifyFailedServer(server)
   292  		s.logger.Printf("[ERR] consul: RPC failed to server %s in DC %q: %v", server.Addr, dc, err)
   293  		return err
   294  	}
   295  
   296  	return nil
   297  }
   298  
   299  // globalRPC is used to forward an RPC request to one server in each datacenter.
   300  // This will only error for RPC-related errors. Otherwise, application-level
   301  // errors can be sent in the response objects.
   302  func (s *Server) globalRPC(method string, args interface{},
   303  	reply structs.CompoundResponse) error {
   304  
   305  	// Make a new request into each datacenter
   306  	dcs := s.router.GetDatacenters()
   307  
   308  	replies, total := 0, len(dcs)
   309  	errorCh := make(chan error, total)
   310  	respCh := make(chan interface{}, total)
   311  
   312  	for _, dc := range dcs {
   313  		go func(dc string) {
   314  			rr := reply.New()
   315  			if err := s.forwardDC(method, dc, args, &rr); err != nil {
   316  				errorCh <- err
   317  				return
   318  			}
   319  			respCh <- rr
   320  		}(dc)
   321  	}
   322  
   323  	for replies < total {
   324  		select {
   325  		case err := <-errorCh:
   326  			return err
   327  		case rr := <-respCh:
   328  			reply.Add(rr)
   329  			replies++
   330  		}
   331  	}
   332  	return nil
   333  }
   334  
   335  // raftApply is used to encode a message, run it through raft, and return
   336  // the FSM response along with any errors
   337  func (s *Server) raftApply(t structs.MessageType, msg interface{}) (interface{}, error) {
   338  	buf, err := structs.Encode(t, msg)
   339  	if err != nil {
   340  		return nil, fmt.Errorf("Failed to encode request: %v", err)
   341  	}
   342  
   343  	// Warn if the command is very large
   344  	if n := len(buf); n > raftWarnSize {
   345  		s.logger.Printf("[WARN] consul: Attempting to apply large raft entry (%d bytes)", n)
   346  	}
   347  
   348  	future := s.raft.Apply(buf, enqueueLimit)
   349  	if err := future.Error(); err != nil {
   350  		return nil, err
   351  	}
   352  
   353  	return future.Response(), nil
   354  }
   355  
// queryFn is the callback that performs a single query operation for
// blockingQuery. If a re-query is needed, the passed-in watch set will be
// used to block for changes (it is nil when the query is not a blocking one).
// The passed-in state store should be used (vs. calling fsm.State()) since
// the given state store will be correctly watched for changes if the state
// store is restored from a snapshot.
type queryFn func(memdb.WatchSet, *state.Store) error
   362  
// blockingQuery is used to process a potentially blocking query operation.
//
// When queryOpts.MinQueryIndex is zero, fn is run exactly once and its result
// is returned. Otherwise fn is re-run each time its watch set fires, until it
// returns an error, queryMeta.Index moves past MinQueryIndex, the query
// timeout expires, or the state store is abandoned.
//
// queryOpts.MaxQueryTime is modified in place for blocking queries: it is
// clamped to maxQueryTime, defaulted to defaultQueryTime when not positive,
// and then has jitter added. queryMeta is refreshed before every run of fn.
func (s *Server) blockingQuery(queryOpts *structs.QueryOptions, queryMeta *structs.QueryMeta,
	fn queryFn) error {
	// timeout stays nil on the non-blocking fast path; it is only
	// dereferenced below when MinQueryIndex > 0.
	var timeout *time.Timer

	// Fast path right to the non-blocking query.
	if queryOpts.MinQueryIndex == 0 {
		goto RUN_QUERY
	}

	// Restrict the max query time, and ensure there is always one.
	if queryOpts.MaxQueryTime > maxQueryTime {
		queryOpts.MaxQueryTime = maxQueryTime
	} else if queryOpts.MaxQueryTime <= 0 {
		queryOpts.MaxQueryTime = defaultQueryTime
	}

	// Apply a small amount of jitter to the request.
	queryOpts.MaxQueryTime += lib.RandomStagger(queryOpts.MaxQueryTime / jitterFraction)

	// Setup a query timeout. The timer is created once, so it bounds the
	// total time across all re-queries.
	timeout = time.NewTimer(queryOpts.MaxQueryTime)
	defer timeout.Stop()

RUN_QUERY:
	// Update the query metadata.
	s.setQueryMeta(queryMeta)

	// If the read must be consistent we verify that we are still the leader.
	if queryOpts.RequireConsistent {
		if err := s.consistentRead(); err != nil {
			return err
		}
	}

	// Run the query.
	metrics.IncrCounter([]string{"rpc", "query"}, 1)

	// Operate on a consistent set of state. This makes sure that the
	// abandon channel goes with the state that the caller is using to
	// build watches. (This local shadows the imported state package for the
	// rest of the function.)
	state := s.fsm.State()

	// We can skip all watch tracking if this isn't a blocking query.
	var ws memdb.WatchSet
	if queryOpts.MinQueryIndex > 0 {
		ws = memdb.NewWatchSet()

		// This channel will be closed if a snapshot is restored and the
		// whole state store is abandoned.
		ws.Add(state.AbandonCh())
	}

	// Block up to the timeout if we didn't see anything fresh.
	err := fn(ws, state)
	// Note we check queryOpts.MinQueryIndex is greater than zero to determine if
	// blocking was requested by client, NOT meta.Index since the state function
	// might return zero if something is not initialized and care wasn't taken to
	// handle that special case (in practice this happened a lot so fixing it
	// systematically here beats trying to remember to add zero checks in every
	// state method). We also need to ensure that unless there is an error, we
	// return an index > 0 otherwise the client will never block and burn CPU and
	// requests.
	if err == nil && queryMeta.Index < 1 {
		queryMeta.Index = 1
	}
	if err == nil && queryOpts.MinQueryIndex > 0 && queryMeta.Index <= queryOpts.MinQueryIndex {
		// Nothing newer than the client's index yet: wait for a watched
		// channel to fire or for the timeout.
		if expired := ws.Watch(timeout.C); !expired {
			// If a restore may have woken us up then bail out from
			// the query immediately. This is slightly race-ey since
			// this might have been interrupted for other reasons,
			// but it's OK to kick it back to the caller in either
			// case.
			select {
			case <-state.AbandonCh():
			default:
				goto RUN_QUERY
			}
		}
	}
	return err
}
   445  
   446  // setQueryMeta is used to populate the QueryMeta data for an RPC call
   447  func (s *Server) setQueryMeta(m *structs.QueryMeta) {
   448  	if s.IsLeader() {
   449  		m.LastContact = 0
   450  		m.KnownLeader = true
   451  	} else {
   452  		m.LastContact = time.Since(s.raft.LastContact())
   453  		m.KnownLeader = (s.raft.Leader() != "")
   454  	}
   455  }
   456  
   457  // consistentRead is used to ensure we do not perform a stale
   458  // read. This is done by verifying leadership before the read.
   459  func (s *Server) consistentRead() error {
   460  	defer metrics.MeasureSince([]string{"rpc", "consistentRead"}, time.Now())
   461  	future := s.raft.VerifyLeader()
   462  	if err := future.Error(); err != nil {
   463  		return err //fail fast if leader verification fails
   464  	}
   465  	// poll consistent read readiness, wait for up to RPCHoldTimeout milliseconds
   466  	if s.isReadyForConsistentReads() {
   467  		return nil
   468  	}
   469  	jitter := lib.RandomStagger(s.config.RPCHoldTimeout / jitterFraction)
   470  	deadline := time.Now().Add(s.config.RPCHoldTimeout)
   471  
   472  	for time.Now().Before(deadline) {
   473  
   474  		select {
   475  		case <-time.After(jitter):
   476  			// Drop through and check before we loop again.
   477  
   478  		case <-s.shutdownCh:
   479  			return fmt.Errorf("shutdown waiting for leader")
   480  		}
   481  
   482  		if s.isReadyForConsistentReads() {
   483  			return nil
   484  		}
   485  	}
   486  
   487  	return structs.ErrNotReadyForConsistentReads
   488  }