github.com/mutagen-io/mutagen@v0.18.0-rc1/pkg/synchronization/endpoint/remote/server.go (about)

     1  package remote
     2  
     3  import (
     4  	"bufio"
     5  	"context"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  
    10  	"google.golang.org/protobuf/proto"
    11  
    12  	"github.com/mutagen-io/mutagen/pkg/encoding"
    13  	"github.com/mutagen-io/mutagen/pkg/filesystem"
    14  	"github.com/mutagen-io/mutagen/pkg/logging"
    15  	streampkg "github.com/mutagen-io/mutagen/pkg/stream"
    16  	"github.com/mutagen-io/mutagen/pkg/synchronization"
    17  	"github.com/mutagen-io/mutagen/pkg/synchronization/compression"
    18  	"github.com/mutagen-io/mutagen/pkg/synchronization/core"
    19  	"github.com/mutagen-io/mutagen/pkg/synchronization/endpoint/local"
    20  	"github.com/mutagen-io/mutagen/pkg/synchronization/rsync"
    21  )
    22  
    23  // endpointServer wraps a local endpoint instances and dispatches requests to
    24  // this endpoint from an endpoint client.
    25  type endpointServer struct {
    26  	// endpoint is the underlying local endpoint.
    27  	endpoint synchronization.Endpoint
    28  	// flusher flushes the outbound control stream.
    29  	flusher streampkg.Flusher
    30  	// encoder is the control stream encoder.
    31  	encoder *encoding.ProtobufEncoder
    32  	// decoder is the control stream decoder.
    33  	decoder *encoding.ProtobufDecoder
    34  }
    35  
    36  // ServeEndpoint creates and serves a endpoint server on the specified stream.
    37  // It enforces that the provided stream is closed by the time this function
    38  // returns, regardless of failure. The provided stream must unblock read and
    39  // write operations when closed.
    40  func ServeEndpoint(logger *logging.Logger, stream io.ReadWriteCloser) error {
    41  	// Perform the compression handshake.
    42  	compressionAlgorithm, err := compression.ServerHandshake(stream)
    43  	if err != nil {
    44  		stream.Close()
    45  		return fmt.Errorf("compression handshake failed: %w", err)
    46  	}
    47  
    48  	// Set up inbound buffering and decompression. While the decompressor does
    49  	// have some internal buffering, we need the inbound stream to support
    50  	// io.ByteReader for our Protocol Buffer decoding, so we add a bufio.Reader
    51  	// around it with additional buffering.
    52  	compressedInbound := bufio.NewReaderSize(stream, controlStreamCompressedBufferSize)
    53  	decompressor := compressionAlgorithm.Decompress(compressedInbound)
    54  	inbound := bufio.NewReaderSize(decompressor, controlStreamUncompressedBufferSize)
    55  
    56  	// Set up outbound buffering and compression.
    57  	compressedOutbound := bufio.NewWriterSize(stream, controlStreamCompressedBufferSize)
    58  	compressor := compressionAlgorithm.Compress(compressedOutbound)
    59  	outbound := bufio.NewWriterSize(compressor, controlStreamUncompressedBufferSize)
    60  
    61  	// Create a mechanism to flush the outbound pipeline.
    62  	flusher := streampkg.NewMultiFlusher(outbound, compressor, compressedOutbound)
    63  
    64  	// Create a closer for the control stream and compression resources and
    65  	// defer its invocation.
    66  	closer := streampkg.NewMultiCloser(
    67  		streampkg.NewFlushCloser(outbound),
    68  		compressor,
    69  		streampkg.NewFlushCloser(compressedOutbound),
    70  		stream,
    71  		decompressor,
    72  	)
    73  	defer closer.Close()
    74  
    75  	// Create an encoder and a decoder for Protocol Buffers messages.
    76  	encoder := encoding.NewProtobufEncoder(outbound)
    77  	decoder := encoding.NewProtobufDecoder(inbound)
    78  
    79  	// Receive the initialize request. If this fails, then send a failure
    80  	// response (even though the pipe is probably broken) and abort.
    81  	request := &InitializeSynchronizationRequest{}
    82  	if err := decoder.Decode(request); err != nil {
    83  		err = fmt.Errorf("unable to receive initialize request: %w", err)
    84  		encoder.Encode(&InitializeSynchronizationResponse{Error: err.Error()})
    85  		flusher.Flush()
    86  		return err
    87  	}
    88  
    89  	// Ensure that the initialization request is valid.
    90  	if err := request.ensureValid(); err != nil {
    91  		err = fmt.Errorf("invalid initialize request: %w", err)
    92  		encoder.Encode(&InitializeSynchronizationResponse{Error: err.Error()})
    93  		flusher.Flush()
    94  		return err
    95  	}
    96  
    97  	// Expand and normalize the root path.
    98  	if r, err := filesystem.Normalize(request.Root); err != nil {
    99  		err = fmt.Errorf("unable to normalize synchronization root: %w", err)
   100  		encoder.Encode(&InitializeSynchronizationResponse{Error: err.Error()})
   101  		flusher.Flush()
   102  		return err
   103  	} else {
   104  		request.Root = r
   105  	}
   106  
   107  	// Create the underlying endpoint. If it fails to create, then send a
   108  	// failure response and abort. If it succeeds, then defer its closure.
   109  	endpoint, err := local.NewEndpoint(
   110  		logger,
   111  		request.Root,
   112  		request.Session,
   113  		request.Version,
   114  		request.Configuration,
   115  		request.Alpha,
   116  	)
   117  	if err != nil {
   118  		err = fmt.Errorf("unable to create underlying endpoint: %w", err)
   119  		encoder.Encode(&InitializeSynchronizationResponse{Error: err.Error()})
   120  		flusher.Flush()
   121  		return err
   122  	}
   123  	defer endpoint.Shutdown()
   124  
   125  	// Send a successful initialize response.
   126  	if err = encoder.Encode(&InitializeSynchronizationResponse{}); err != nil {
   127  		return fmt.Errorf("unable to encode initialize response: %w", err)
   128  	} else if err = flusher.Flush(); err != nil {
   129  		return fmt.Errorf("unable to transmit initialize response: %w", err)
   130  	}
   131  
   132  	// Create the server.
   133  	server := &endpointServer{
   134  		endpoint: endpoint,
   135  		flusher:  flusher,
   136  		encoder:  encoder,
   137  		decoder:  decoder,
   138  	}
   139  
   140  	// Server until an error occurs.
   141  	return server.serve()
   142  }
   143  
   144  // encodeAndFlush encodes a Protocol Buffers message using the underlying
   145  // encoder and then flushes the control stream.
   146  func (s *endpointServer) encodeAndFlush(message proto.Message) error {
   147  	if err := s.encoder.Encode(message); err != nil {
   148  		return err
   149  	} else if err = s.flusher.Flush(); err != nil {
   150  		return fmt.Errorf("message transmission failed: %w", err)
   151  	}
   152  	return nil
   153  }
   154  
   155  // serve is the main request handling loop.
   156  func (s *endpointServer) serve() error {
   157  	// Keep a reusable endpoint request.
   158  	request := &EndpointRequest{}
   159  
   160  	// Receive and process control requests until there's an error.
   161  	for {
   162  		// Receive the next request.
   163  		*request = EndpointRequest{}
   164  		if err := s.decoder.Decode(request); err != nil {
   165  			return fmt.Errorf("unable to receive request: %w", err)
   166  		} else if err = request.ensureValid(); err != nil {
   167  			return fmt.Errorf("invalid endpoint request: %w", err)
   168  		}
   169  
   170  		// Handle the request based on type.
   171  		if request.Poll != nil {
   172  			if err := s.servePoll(request.Poll); err != nil {
   173  				return fmt.Errorf("unable to serve poll request: %w", err)
   174  			}
   175  		} else if request.Scan != nil {
   176  			if err := s.serveScan(request.Scan); err != nil {
   177  				return fmt.Errorf("unable to serve scan request: %w", err)
   178  			}
   179  		} else if request.Stage != nil {
   180  			if err := s.serveStage(request.Stage); err != nil {
   181  				return fmt.Errorf("unable to serve stage request: %w", err)
   182  			}
   183  		} else if request.Supply != nil {
   184  			if err := s.serveSupply(request.Supply); err != nil {
   185  				return fmt.Errorf("unable to serve supply request: %w", err)
   186  			}
   187  		} else if request.Transition != nil {
   188  			if err := s.serveTransition(request.Transition); err != nil {
   189  				return fmt.Errorf("unable to serve transition request: %w", err)
   190  			}
   191  		} else {
   192  			// TODO: Should we panic here? The request validation already
   193  			// ensures that one and only one message component is set, so we
   194  			// should never hit this condition.
   195  			return errors.New("invalid request")
   196  		}
   197  	}
   198  }
   199  
   200  // servePoll serves a poll request.
   201  func (s *endpointServer) servePoll(request *PollRequest) error {
   202  	// Ensure the request is valid.
   203  	if err := request.ensureValid(); err != nil {
   204  		return fmt.Errorf("invalid poll request: %w", err)
   205  	}
   206  
   207  	// Create a cancellable context for executing the poll.
   208  	ctx, cancel := context.WithCancel(context.Background())
   209  
   210  	// Start a Goroutine to watch for the completion request.
   211  	completionReceiveErrors := make(chan error, 1)
   212  	go func() {
   213  		request := &PollCompletionRequest{}
   214  		if err := s.decoder.Decode(request); err != nil {
   215  			completionReceiveErrors <- fmt.Errorf("unable to receive completion request: %w", err)
   216  		} else if err = request.ensureValid(); err != nil {
   217  			completionReceiveErrors <- fmt.Errorf("received invalid completion request: %w", err)
   218  		} else {
   219  			completionReceiveErrors <- nil
   220  		}
   221  	}()
   222  
   223  	// Start a Goroutine to execute the poll and send a response when done.
   224  	responseSendErrors := make(chan error, 1)
   225  	go func() {
   226  		// Perform polling and set up the response.
   227  		var response *PollResponse
   228  		if err := s.endpoint.Poll(ctx); err != nil {
   229  			response = &PollResponse{
   230  				Error: err.Error(),
   231  			}
   232  		} else {
   233  			response = &PollResponse{}
   234  		}
   235  
   236  		// Send te response.
   237  		if err := s.encodeAndFlush(response); err != nil {
   238  			responseSendErrors <- fmt.Errorf("unable to transmit response: %w", err)
   239  		} else {
   240  			responseSendErrors <- nil
   241  		}
   242  	}()
   243  
   244  	// Wait for both a completion request to be received and a response to be
   245  	// sent. Both of these will occur, though their order is not known. If the
   246  	// completion request is received first, then we cancel the subcontext to
   247  	// preempt the scan and force transmission of a response. If the response is
   248  	// sent first, then we know the completion request is on its way. In this
   249  	// case, we still cancel the subcontext we created as required by the
   250  	// context package to avoid leaking resources.
   251  	var responseSendErr, completionReceiveErr error
   252  	select {
   253  	case completionReceiveErr = <-completionReceiveErrors:
   254  		cancel()
   255  		responseSendErr = <-responseSendErrors
   256  	case responseSendErr = <-responseSendErrors:
   257  		cancel()
   258  		completionReceiveErr = <-completionReceiveErrors
   259  	}
   260  
   261  	// Check for errors.
   262  	if responseSendErr != nil {
   263  		return responseSendErr
   264  	} else if completionReceiveErr != nil {
   265  		return completionReceiveErr
   266  	}
   267  
   268  	// Success.
   269  	return nil
   270  }
   271  
   272  // serveScan serves a scan request.
   273  func (s *endpointServer) serveScan(request *ScanRequest) error {
   274  	// Ensure the request is valid.
   275  	if err := request.ensureValid(); err != nil {
   276  		return fmt.Errorf("invalid scan request: %w", err)
   277  	}
   278  
   279  	// Create a cancellable context for executing the scan.
   280  	ctx, cancel := context.WithCancel(context.Background())
   281  
   282  	// Start a Goroutine to watch for the completion request.
   283  	completionReceiveErrors := make(chan error, 1)
   284  	go func() {
   285  		request := &ScanCompletionRequest{}
   286  		if err := s.decoder.Decode(request); err != nil {
   287  			completionReceiveErrors <- fmt.Errorf("unable to receive completion request: %w", err)
   288  		} else if err = request.ensureValid(); err != nil {
   289  			completionReceiveErrors <- fmt.Errorf("received invalid completion request: %w", err)
   290  		} else {
   291  			completionReceiveErrors <- nil
   292  		}
   293  	}()
   294  
   295  	// Start a Goroutine to execute the scan and send a response when done.
   296  	responseSendErrors := make(chan error, 1)
   297  	go func() {
   298  		// Configure Protocol Buffers marshaling to be deterministic.
   299  		marshaling := proto.MarshalOptions{Deterministic: true}
   300  
   301  		// Create an rsync engine.
   302  		engine := rsync.NewEngine()
   303  
   304  		// Perform a scan and set up the response.
   305  		var response *ScanResponse
   306  		snapshot, err, tryAgain := s.endpoint.Scan(ctx, nil, request.Full)
   307  		if err != nil {
   308  			response = &ScanResponse{
   309  				Error:    err.Error(),
   310  				TryAgain: tryAgain,
   311  			}
   312  		} else if snapshotBytes, err := marshaling.Marshal(snapshot); err != nil {
   313  			response = &ScanResponse{
   314  				Error: fmt.Errorf("unable to marshal snapshot: %w", err).Error(),
   315  			}
   316  		} else {
   317  			response = &ScanResponse{
   318  				SnapshotDelta: engine.DeltifyBytes(
   319  					snapshotBytes,
   320  					request.BaselineSnapshotSignature,
   321  					0,
   322  				),
   323  			}
   324  		}
   325  
   326  		// Send the response.
   327  		if err := s.encodeAndFlush(response); err != nil {
   328  			responseSendErrors <- fmt.Errorf("unable to transmit response: %w", err)
   329  		} else {
   330  			responseSendErrors <- nil
   331  		}
   332  	}()
   333  
   334  	// Wait for both a completion request to be received and a response to be
   335  	// sent. Both of these will occur, though their order is not known. If the
   336  	// completion request is received first, then we cancel the subcontext to
   337  	// preempt the scan and force transmission of a response. If the response is
   338  	// sent first, then we know the completion request is on its way. In this
   339  	// case, we still cancel the subcontext we created as required by the
   340  	// context package to avoid leaking resources.
   341  	var responseSendErr, completionReceiveErr error
   342  	select {
   343  	case completionReceiveErr = <-completionReceiveErrors:
   344  		cancel()
   345  		responseSendErr = <-responseSendErrors
   346  	case responseSendErr = <-responseSendErrors:
   347  		cancel()
   348  		completionReceiveErr = <-completionReceiveErrors
   349  	}
   350  
   351  	// Check for errors.
   352  	if responseSendErr != nil {
   353  		return responseSendErr
   354  	} else if completionReceiveErr != nil {
   355  		return completionReceiveErr
   356  	}
   357  
   358  	// Success.
   359  	return nil
   360  }
   361  
   362  // serveStage serves a stage request.
   363  func (s *endpointServer) serveStage(request *StageRequest) error {
   364  	// Ensure the request is valid.
   365  	if err := request.ensureValid(); err != nil {
   366  		return fmt.Errorf("invalid stage request: %w", err)
   367  	}
   368  
   369  	// Begin staging.
   370  	paths, signatures, receiver, err := s.endpoint.Stage(request.Paths, request.Digests)
   371  	if err != nil {
   372  		s.encodeAndFlush(&StageResponse{Error: err.Error()})
   373  		return fmt.Errorf("unable to begin staging: %w", err)
   374  	}
   375  
   376  	// If all of the requested paths are required, then we'll signal this in the
   377  	// response by using an empty path list. This is an important heuristic to
   378  	// reduce response size on initial staging.
   379  	responsePaths := paths
   380  	if len(responsePaths) == len(request.Paths) {
   381  		responsePaths = nil
   382  	}
   383  
   384  	// Send the response.
   385  	response := &StageResponse{
   386  		Paths:      responsePaths,
   387  		Signatures: signatures,
   388  	}
   389  	if err = s.encodeAndFlush(response); err != nil {
   390  		return fmt.Errorf("unable to send stage response: %w", err)
   391  	}
   392  
   393  	// If there weren't any paths requiring staging, then we're done.
   394  	if len(paths) == 0 {
   395  		return nil
   396  	}
   397  
   398  	// The remote side of the connection should now forward rsync operations, so
   399  	// we need to decode and forward them to the receiver. If this operation
   400  	// completes successfully, staging is complete and successful.
   401  	decoder := &protobufRsyncDecoder{decoder: s.decoder}
   402  	if err = rsync.DecodeToReceiver(decoder, uint64(len(paths)), receiver); err != nil {
   403  		return fmt.Errorf("unable to decode and forward rsync operations: %w", err)
   404  	}
   405  
   406  	// Success.
   407  	return nil
   408  }
   409  
   410  // serveSupply serves a supply request.
   411  func (s *endpointServer) serveSupply(request *SupplyRequest) error {
   412  	// Ensure the request is valid.
   413  	if err := request.ensureValid(); err != nil {
   414  		return fmt.Errorf("invalid supply request: %w", err)
   415  	}
   416  
   417  	// Create an encoding receiver to transmit rsync operations to the remote.
   418  	encoder := &protobufRsyncEncoder{encoder: s.encoder, flusher: s.flusher}
   419  	receiver := rsync.NewEncodingReceiver(encoder)
   420  
   421  	// Perform supplying.
   422  	if err := s.endpoint.Supply(request.Paths, request.Signatures, receiver); err != nil {
   423  		return fmt.Errorf("unable to perform supplying: %w", err)
   424  	}
   425  
   426  	// Success.
   427  	return nil
   428  }
   429  
   430  // serveTransition serves a transition request.
   431  func (s *endpointServer) serveTransition(request *TransitionRequest) error {
   432  	// Ensure the request is valid.
   433  	if err := request.ensureValid(); err != nil {
   434  		return fmt.Errorf("invalid transition request: %w", err)
   435  	}
   436  
   437  	// Create a cancellable context for executing the transition.
   438  	ctx, cancel := context.WithCancel(context.Background())
   439  
   440  	// Start a Goroutine to watch for the completion request.
   441  	completionReceiveErrors := make(chan error, 1)
   442  	go func() {
   443  		request := &TransitionCompletionRequest{}
   444  		if err := s.decoder.Decode(request); err != nil {
   445  			completionReceiveErrors <- fmt.Errorf("unable to receive completion request: %w", err)
   446  		} else if err = request.ensureValid(); err != nil {
   447  			completionReceiveErrors <- fmt.Errorf("received invalid completion request: %w", err)
   448  		} else {
   449  			completionReceiveErrors <- nil
   450  		}
   451  	}()
   452  
   453  	// Start a Goroutine to execute the transition and send a response when
   454  	// done.
   455  	responseSendErrors := make(chan error, 1)
   456  	go func() {
   457  		// Perform the transition and set up the response.
   458  		var response *TransitionResponse
   459  		results, problems, stagerMissingFiles, err := s.endpoint.Transition(ctx, request.Transitions)
   460  		if err != nil {
   461  			response = &TransitionResponse{
   462  				Error: err.Error(),
   463  			}
   464  		} else {
   465  			// HACK: Wrap the results in Archives since Protocol Buffers can't
   466  			// encode nil pointers in the result array.
   467  			wrappedResults := make([]*core.Archive, len(results))
   468  			for r, result := range results {
   469  				wrappedResults[r] = &core.Archive{Content: result}
   470  			}
   471  			response = &TransitionResponse{
   472  				Results:            wrappedResults,
   473  				Problems:           problems,
   474  				StagerMissingFiles: stagerMissingFiles,
   475  			}
   476  		}
   477  
   478  		// Send the response.
   479  		if err := s.encodeAndFlush(response); err != nil {
   480  			responseSendErrors <- fmt.Errorf("unable to transmit response: %w", err)
   481  		} else {
   482  			responseSendErrors <- nil
   483  		}
   484  	}()
   485  
   486  	// Wait for both a completion request to be received and a response to be
   487  	// sent. Both of these will occur, though their order is not known. If the
   488  	// completion request is received first, then we cancel the subcontext to
   489  	// preempt the transition and force transmission of a response. If the
   490  	// response is sent first, then we know the completion request is on its
   491  	// way. In this case, we still cancel the subcontext we created as required
   492  	// by the context package to avoid leaking resources.
   493  	var responseSendErr, completionReceiveErr error
   494  	select {
   495  	case completionReceiveErr = <-completionReceiveErrors:
   496  		cancel()
   497  		responseSendErr = <-responseSendErrors
   498  	case responseSendErr = <-responseSendErrors:
   499  		cancel()
   500  		completionReceiveErr = <-completionReceiveErrors
   501  	}
   502  
   503  	// Check for errors.
   504  	if responseSendErr != nil {
   505  		return responseSendErr
   506  	} else if completionReceiveErr != nil {
   507  		return completionReceiveErr
   508  	}
   509  
   510  	// Success.
   511  	return nil
   512  }