github.imxd.top/hashicorp/consul@v1.4.5/agent/consul/snapshot_endpoint.go

// The snapshot endpoint is a special non-RPC endpoint that supports streaming
// for taking and restoring snapshots for disaster recovery. This gets wired
// directly into Consul's stream handler, and a new TCP connection is made for
// each request.
//
// This also includes a SnapshotRPC() function, which acts as a lightweight
// client that knows the details of the stream protocol.
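//
// For reference, the wire exchange implemented by SnapshotRPC and
// handleSnapshotRequest below is: the client sends the pool.RPCSnapshot type
// byte, a msgpack-encoded structs.SnapshotRequest, and (for a restore) the raw
// snapshot data, then half-closes its side of the connection; the server
// answers with a msgpack-encoded structs.SnapshotResponse followed by (for a
// save) the raw snapshot data.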
package consul

import (
	"bytes"
	"errors"
	"fmt"
	"io"
	"io/ioutil"
	"net"
	"time"

	"github.com/hashicorp/consul/acl"
	"github.com/hashicorp/consul/agent/pool"
	"github.com/hashicorp/consul/agent/structs"
	"github.com/hashicorp/consul/snapshot"
	"github.com/hashicorp/go-msgpack/codec"
)

// dispatchSnapshotRequest takes an incoming request structure with possibly some
// streaming data (for a restore) and returns possibly some streaming data (for
// a snapshot save). We can't use the normal RPC mechanism in a streaming manner
// like this, so we have to dispatch these by hand.
func (s *Server) dispatchSnapshotRequest(args *structs.SnapshotRequest, in io.Reader,
	reply *structs.SnapshotResponse) (io.ReadCloser, error) {

	// Perform DC forwarding.
	if dc := args.Datacenter; dc != s.config.Datacenter {
		manager, server, ok := s.router.FindRoute(dc)
		if !ok {
			return nil, structs.ErrNoDCPath
		}

		snap, err := SnapshotRPC(s.connPool, dc, server.Addr, server.UseTLS, args, in, reply)
		if err != nil {
			manager.NotifyFailedServer(server)
			return nil, err
		}

		return snap, nil
	}

	// Perform leader forwarding if required.
	if !args.AllowStale {
		if isLeader, server := s.getLeader(); !isLeader {
			if server == nil {
				return nil, structs.ErrNoLeader
			}
			return SnapshotRPC(s.connPool, args.Datacenter, server.Addr, server.UseTLS, args, in, reply)
		}
	}

	// Verify token is allowed to operate on snapshots. There's only a
	// single ACL sense here (not read and write) since reading gets you
	// all the ACLs and you could escalate from there.
	if rule, err := s.ResolveToken(args.Token); err != nil {
		return nil, err
	} else if rule != nil && !rule.Snapshot() {
		return nil, acl.ErrPermissionDenied
	}

	// Dispatch the operation.
	switch args.Op {
	case structs.SnapshotSave:
		if !args.AllowStale {
			if err := s.consistentRead(); err != nil {
				return nil, err
			}
		}

		// Set the metadata here before we do anything; this should always be
		// pessimistic if we get more data while the snapshot is being taken.
		s.setQueryMeta(&reply.QueryMeta)

		// Take the snapshot and capture the index.
		snap, err := snapshot.New(s.logger, s.raft)
		reply.Index = snap.Index()
		return snap, err

	case structs.SnapshotRestore:
		if args.AllowStale {
			return nil, fmt.Errorf("stale not allowed for restore")
		}

		// Restore the snapshot.
		if err := snapshot.Restore(s.logger, in, s.raft); err != nil {
			return nil, err
		}

		// Run a barrier so we are sure that our FSM is caught up with
		// any snapshot restore details (it's also part of Raft's restore
		// process but we don't want to depend on that detail for this to
		// be correct). Once that works, we can redo the leader actions
		// so our leader-maintained state will be up to date.
		barrier := s.raft.Barrier(0)
		if err := barrier.Error(); err != nil {
			return nil, err
		}

		// This'll be used for feedback from the leader loop.
		errCh := make(chan error, 1)
		timeoutCh := time.After(time.Minute)

		select {
		// Tell the leader loop to reassert leader actions since we just
		// replaced the state store contents.
		case s.reassertLeaderCh <- errCh:

		// We might have lost leadership while waiting to kick the loop.
		case <-timeoutCh:
			return nil, fmt.Errorf("timed out waiting to re-run leader actions")

		// Make sure we don't get stuck during shutdown
		case <-s.shutdownCh:
		}

		select {
		// Wait for the leader loop to finish up.
		case err := <-errCh:
			if err != nil {
				return nil, err
			}

		// We might have lost leadership while the loop was doing its
		// thing.
		case <-timeoutCh:
			return nil, fmt.Errorf("timed out waiting for re-run of leader actions")

		// Make sure we don't get stuck during shutdown
		case <-s.shutdownCh:
		}

		// Give the caller back an empty reader since there's nothing to
		// stream back.
		return ioutil.NopCloser(bytes.NewReader([]byte(""))), nil

	default:
		return nil, fmt.Errorf("unrecognized snapshot op %q", args.Op)
	}
}
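
// restoreDispatchSketch is a hypothetical, minimal usage sketch (not part of
// the original file) showing how an in-process caller might hand a snapshot
// archive to dispatchSnapshotRequest for a restore. The token and src
// parameters are assumptions: src is assumed to hold a previously saved
// snapshot supplied by the caller.
func (s *Server) restoreDispatchSketch(token string, src io.Reader) error {
	args := structs.SnapshotRequest{
		Datacenter: s.config.Datacenter,
		Token:      token,
		Op:         structs.SnapshotRestore,
	}

	// Restores must be applied via the leader, so AllowStale is left false
	// and leader forwarding happens inside dispatchSnapshotRequest.
	var reply structs.SnapshotResponse
	out, err := s.dispatchSnapshotRequest(&args, src, &reply)
	if err != nil {
		return err
	}

	// A restore streams nothing back, but the returned reader still needs to
	// be closed.
	return out.Close()
}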

// handleSnapshotRequest reads the request from the conn and dispatches it. This
// will be called from a goroutine after an incoming stream is determined to be
// a snapshot request.
func (s *Server) handleSnapshotRequest(conn net.Conn) error {
	var args structs.SnapshotRequest
	dec := codec.NewDecoder(conn, &codec.MsgpackHandle{})
	if err := dec.Decode(&args); err != nil {
		return fmt.Errorf("failed to decode request: %v", err)
	}

	var reply structs.SnapshotResponse
	snap, err := s.dispatchSnapshotRequest(&args, conn, &reply)
	if err != nil {
		reply.Error = err.Error()
		goto RESPOND
	}
	defer func() {
		if err := snap.Close(); err != nil {
			s.logger.Printf("[ERR] consul: Failed to close snapshot: %v", err)
		}
	}()

RESPOND:
	enc := codec.NewEncoder(conn, &codec.MsgpackHandle{})
	if err := enc.Encode(&reply); err != nil {
		return fmt.Errorf("failed to encode response: %v", err)
	}
	if snap != nil {
		if _, err := io.Copy(conn, snap); err != nil {
			return fmt.Errorf("failed to stream snapshot: %v", err)
		}
	}

	return nil
}
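
// handleSnapshotStreamSketch is a hypothetical sketch (the real dispatch lives
// in Consul's stream handler, which is not part of this file and may differ)
// of how a connection whose leading type byte was pool.RPCSnapshot might be
// handed to handleSnapshotRequest on its own goroutine, matching the package
// comment about a new TCP connection per request.
func (s *Server) handleSnapshotStreamSketch(conn net.Conn) {
	go func() {
		// The handler owns the connection for the life of the request.
		defer conn.Close()
		if err := s.handleSnapshotRequest(conn); err != nil {
			s.logger.Printf("[ERR] consul: Snapshot RPC error: %v", err)
		}
	}()
}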

// SnapshotRPC is a streaming client function for performing a snapshot RPC
// request to a remote server. It will create a fresh connection for each
// request, send the request header, and then stream in any data from the
// reader (for a restore). It will then parse the received response header, and
// if there's no error will return an io.ReadCloser (that you must close) with
// the streaming output (for a snapshot). If the reply contains an error, this
// will always return an error as well, so you don't need to check the error
// inside the filled-in reply.
func SnapshotRPC(connPool *pool.ConnPool, dc string, addr net.Addr, useTLS bool,
	args *structs.SnapshotRequest, in io.Reader, reply *structs.SnapshotResponse) (io.ReadCloser, error) {

	conn, hc, err := connPool.DialTimeout(dc, addr, 10*time.Second, useTLS)
	if err != nil {
		return nil, err
	}

	// keep will disarm the deferred close on success, once we hand the
	// caller our connection to stream the output.
	var keep bool
	defer func() {
		if !keep {
			conn.Close()
		}
	}()

	// Write the snapshot RPC byte to set the mode, then perform the
	// request.
	if _, err := conn.Write([]byte{byte(pool.RPCSnapshot)}); err != nil {
		return nil, fmt.Errorf("failed to write stream type: %v", err)
	}

	// Push the header encoded as msgpack, then stream the input.
	enc := codec.NewEncoder(conn, &codec.MsgpackHandle{})
	if err := enc.Encode(&args); err != nil {
		return nil, fmt.Errorf("failed to encode request: %v", err)
	}
	if _, err := io.Copy(conn, in); err != nil {
		return nil, fmt.Errorf("failed to copy snapshot in: %v", err)
	}

	// Our RPC protocol requires support for a half-close in order to signal
	// the other side that they are done reading the stream, since we don't
	// know the size in advance. This saves us from having to buffer just to
	// calculate the size.
	if hc != nil {
		if err := hc.CloseWrite(); err != nil {
			return nil, fmt.Errorf("failed to half close snapshot connection: %v", err)
		}
	} else {
		return nil, fmt.Errorf("snapshot connection requires half-close support")
	}

	// Pull the header decoded as msgpack. The caller can continue to read
	// the conn to stream the remaining data.
	dec := codec.NewDecoder(conn, &codec.MsgpackHandle{})
	if err := dec.Decode(reply); err != nil {
		return nil, fmt.Errorf("failed to decode response: %v", err)
	}
	if reply.Error != "" {
		return nil, errors.New(reply.Error)
	}

	keep = true
	return conn, nil
}
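
// saveSnapshotSketch is a hypothetical, minimal usage sketch (not part of the
// original file) showing how a caller might drive SnapshotRPC for a save
// operation and stream the result into an arbitrary writer. The connPool, dc,
// addr, useTLS, and token parameters are assumed to be supplied by the
// surrounding agent.
func saveSnapshotSketch(connPool *pool.ConnPool, dc string, addr net.Addr, useTLS bool,
	token string, out io.Writer) error {

	args := structs.SnapshotRequest{
		Datacenter: dc,
		Token:      token,
		Op:         structs.SnapshotSave,
	}

	// A save sends no request body, so an empty reader is supplied for the
	// input side of the stream.
	var reply structs.SnapshotResponse
	snap, err := SnapshotRPC(connPool, dc, addr, useTLS, &args, bytes.NewReader(nil), &reply)
	if err != nil {
		return err
	}
	defer snap.Close()

	// Everything left on the connection after the response header is the
	// snapshot itself.
	if _, err := io.Copy(out, snap); err != nil {
		return fmt.Errorf("failed to stream snapshot to output: %v", err)
	}
	return nil
}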