github.com/MetalBlockchain/metalgo@v1.11.9/x/sync/network_server.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package sync
     5  
     6  import (
     7  	"bytes"
     8  	"context"
     9  	"errors"
    10  	"fmt"
    11  	"time"
    12  
    13  	"go.uber.org/zap"
    14  	"google.golang.org/grpc/codes"
    15  	"google.golang.org/grpc/status"
    16  	"google.golang.org/protobuf/proto"
    17  
    18  	"github.com/MetalBlockchain/metalgo/ids"
    19  	"github.com/MetalBlockchain/metalgo/snow/engine/common"
    20  	"github.com/MetalBlockchain/metalgo/utils/constants"
    21  	"github.com/MetalBlockchain/metalgo/utils/hashing"
    22  	"github.com/MetalBlockchain/metalgo/utils/logging"
    23  	"github.com/MetalBlockchain/metalgo/utils/maybe"
    24  	"github.com/MetalBlockchain/metalgo/utils/units"
    25  	"github.com/MetalBlockchain/metalgo/x/merkledb"
    26  
    27  	pb "github.com/MetalBlockchain/metalgo/proto/pb/sync"
    28  )
    29  
    30  const (
    31  	// Maximum number of key-value pairs to return in a proof.
    32  	// This overrides any other Limit specified in a RangeProofRequest
    33  	// or ChangeProofRequest if the given Limit is greater.
    34  	maxKeyValuesLimit = 2048
    35  	// Estimated max overhead, in bytes, of putting a proof into a message.
    36  	// We use this to ensure that the proof we generate is not too large to fit in a message.
    37  	// TODO: refine this estimate. This is almost certainly a large overestimate.
    38  	estimatedMessageOverhead = 4 * units.KiB
    39  	maxByteSizeLimit         = constants.DefaultMaxMessageSize - estimatedMessageOverhead
    40  )
    41  
    42  var (
    43  	ErrMinProofSizeIsTooLarge = errors.New("cannot generate any proof within the requested limit")
    44  
    45  	errInvalidBytesLimit    = errors.New("bytes limit must be greater than 0")
    46  	errInvalidKeyLimit      = errors.New("key limit must be greater than 0")
    47  	errInvalidStartRootHash = fmt.Errorf("start root hash must have length %d", hashing.HashLen)
    48  	errInvalidEndRootHash   = fmt.Errorf("end root hash must have length %d", hashing.HashLen)
    49  	errInvalidStartKey      = errors.New("start key is Nothing but has value")
    50  	errInvalidEndKey        = errors.New("end key is Nothing but has value")
    51  	errInvalidBounds        = errors.New("start key is greater than end key")
    52  	errInvalidRootHash      = fmt.Errorf("root hash must have length %d", hashing.HashLen)
    53  )
    54  
// NetworkServer handles sync requests that peers send as AppRequests (see
// AppRequest), answering them with range proofs and change proofs generated
// from db.
type NetworkServer struct {
	appSender common.AppSender // Used to respond to peer requests via AppResponse.
	db        DB               // Source of the range and change proofs sent to peers.
	log       logging.Logger   // Records dropped, invalid, and failed requests.
}
    60  
    61  func NewNetworkServer(appSender common.AppSender, db DB, log logging.Logger) *NetworkServer {
    62  	return &NetworkServer{
    63  		appSender: appSender,
    64  		db:        db,
    65  		log:       log,
    66  	}
    67  }
    68  
    69  // AppRequest is called by avalanchego -> VM when there is an incoming AppRequest from a peer.
    70  // Returns a non-nil error iff we fail to send an app message. This is a fatal error.
    71  // Sends a response back to the sender if length of response returned by the handler > 0.
    72  func (s *NetworkServer) AppRequest(
    73  	ctx context.Context,
    74  	nodeID ids.NodeID,
    75  	requestID uint32,
    76  	deadline time.Time,
    77  	request []byte,
    78  ) error {
    79  	var req pb.Request
    80  	if err := proto.Unmarshal(request, &req); err != nil {
    81  		s.log.Debug(
    82  			"failed to unmarshal AppRequest",
    83  			zap.Stringer("nodeID", nodeID),
    84  			zap.Uint32("requestID", requestID),
    85  			zap.Int("requestLen", len(request)),
    86  			zap.Error(err),
    87  		)
    88  		return nil
    89  	}
    90  	s.log.Debug(
    91  		"processing AppRequest from node",
    92  		zap.Stringer("nodeID", nodeID),
    93  		zap.Uint32("requestID", requestID),
    94  	)
    95  
    96  	// bufferedDeadline is half the time till actual deadline so that the message has a
    97  	// reasonable chance of completing its processing and sending the response to the peer.
    98  	timeTillDeadline := time.Until(deadline)
    99  	bufferedDeadline := time.Now().Add(timeTillDeadline / 2)
   100  
   101  	// check if we have enough time to handle this request.
   102  	// TODO danlaine: Do we need this? Why?
   103  	if time.Until(bufferedDeadline) < minRequestHandlingDuration {
   104  		// Drop the request if we already missed the deadline to respond.
   105  		s.log.Info(
   106  			"deadline to process AppRequest has expired, skipping",
   107  			zap.Stringer("nodeID", nodeID),
   108  			zap.Uint32("requestID", requestID),
   109  		)
   110  		return nil
   111  	}
   112  
   113  	ctx, cancel := context.WithDeadline(ctx, bufferedDeadline)
   114  	defer cancel()
   115  
   116  	var err error
   117  	switch req := req.GetMessage().(type) {
   118  	case *pb.Request_ChangeProofRequest:
   119  		err = s.HandleChangeProofRequest(ctx, nodeID, requestID, req.ChangeProofRequest)
   120  	case *pb.Request_RangeProofRequest:
   121  		err = s.HandleRangeProofRequest(ctx, nodeID, requestID, req.RangeProofRequest)
   122  	default:
   123  		s.log.Debug(
   124  			"unknown AppRequest type",
   125  			zap.Stringer("nodeID", nodeID),
   126  			zap.Uint32("requestID", requestID),
   127  			zap.Int("requestLen", len(request)),
   128  			zap.String("requestType", fmt.Sprintf("%T", req)),
   129  		)
   130  		return nil
   131  	}
   132  
   133  	if err != nil {
   134  		if errors.Is(err, errAppSendFailed) {
   135  			return err
   136  		}
   137  
   138  		if !isTimeout(err) {
   139  			// log unexpected errors instead of returning them, since they are fatal.
   140  			s.log.Warn(
   141  				"unexpected error handling AppRequest",
   142  				zap.Stringer("nodeID", nodeID),
   143  				zap.Uint32("requestID", requestID),
   144  				zap.Error(err),
   145  			)
   146  		}
   147  	}
   148  	return nil
   149  }
   150  
   151  func maybeBytesToMaybe(mb *pb.MaybeBytes) maybe.Maybe[[]byte] {
   152  	if mb != nil && !mb.IsNothing {
   153  		return maybe.Some(mb.Value)
   154  	}
   155  	return maybe.Nothing[[]byte]()
   156  }
   157  
   158  // Generates a change proof and sends it to [nodeID].
   159  // If [errAppSendFailed] is returned, this should be considered fatal.
   160  func (s *NetworkServer) HandleChangeProofRequest(
   161  	ctx context.Context,
   162  	nodeID ids.NodeID,
   163  	requestID uint32,
   164  	req *pb.SyncGetChangeProofRequest,
   165  ) error {
   166  	if err := validateChangeProofRequest(req); err != nil {
   167  		s.log.Debug(
   168  			"dropping invalid change proof request",
   169  			zap.Stringer("nodeID", nodeID),
   170  			zap.Uint32("requestID", requestID),
   171  			zap.Stringer("req", req),
   172  			zap.Error(err),
   173  		)
   174  		return nil // dropping request
   175  	}
   176  
   177  	// override limits if they exceed caps
   178  	var (
   179  		keyLimit   = min(req.KeyLimit, maxKeyValuesLimit)
   180  		bytesLimit = min(int(req.BytesLimit), maxByteSizeLimit)
   181  		start      = maybeBytesToMaybe(req.StartKey)
   182  		end        = maybeBytesToMaybe(req.EndKey)
   183  	)
   184  
   185  	startRoot, err := ids.ToID(req.StartRootHash)
   186  	if err != nil {
   187  		return err
   188  	}
   189  
   190  	endRoot, err := ids.ToID(req.EndRootHash)
   191  	if err != nil {
   192  		return err
   193  	}
   194  
   195  	for keyLimit > 0 {
   196  		changeProof, err := s.db.GetChangeProof(ctx, startRoot, endRoot, start, end, int(keyLimit))
   197  		if err != nil {
   198  			if !errors.Is(err, merkledb.ErrInsufficientHistory) {
   199  				// We should only fail to get a change proof if we have insufficient history.
   200  				// Other errors are unexpected.
   201  				return err
   202  			}
   203  			if errors.Is(err, merkledb.ErrNoEndRoot) {
   204  				// [s.db] doesn't have [endRoot] in its history.
   205  				// We can't generate a change/range proof. Drop this request.
   206  				return nil
   207  			}
   208  
   209  			// [s.db] doesn't have sufficient history to generate change proof.
   210  			// Generate a range proof for the end root ID instead.
   211  			proofBytes, err := getRangeProof(
   212  				ctx,
   213  				s.db,
   214  				&pb.SyncGetRangeProofRequest{
   215  					RootHash:   req.EndRootHash,
   216  					StartKey:   req.StartKey,
   217  					EndKey:     req.EndKey,
   218  					KeyLimit:   req.KeyLimit,
   219  					BytesLimit: req.BytesLimit,
   220  				},
   221  				func(rangeProof *merkledb.RangeProof) ([]byte, error) {
   222  					return proto.Marshal(&pb.SyncGetChangeProofResponse{
   223  						Response: &pb.SyncGetChangeProofResponse_RangeProof{
   224  							RangeProof: rangeProof.ToProto(),
   225  						},
   226  					})
   227  				},
   228  			)
   229  			if err != nil {
   230  				return err
   231  			}
   232  
   233  			if err := s.appSender.SendAppResponse(ctx, nodeID, requestID, proofBytes); err != nil {
   234  				s.log.Fatal(
   235  					"failed to send app response",
   236  					zap.Stringer("nodeID", nodeID),
   237  					zap.Uint32("requestID", requestID),
   238  					zap.Int("responseLen", len(proofBytes)),
   239  					zap.Error(err),
   240  				)
   241  				return fmt.Errorf("%w: %w", errAppSendFailed, err)
   242  			}
   243  			return nil
   244  		}
   245  
   246  		// We generated a change proof. See if it's small enough.
   247  		proofBytes, err := proto.Marshal(&pb.SyncGetChangeProofResponse{
   248  			Response: &pb.SyncGetChangeProofResponse_ChangeProof{
   249  				ChangeProof: changeProof.ToProto(),
   250  			},
   251  		})
   252  		if err != nil {
   253  			return err
   254  		}
   255  
   256  		if len(proofBytes) < bytesLimit {
   257  			if err := s.appSender.SendAppResponse(ctx, nodeID, requestID, proofBytes); err != nil {
   258  				s.log.Fatal(
   259  					"failed to send app response",
   260  					zap.Stringer("nodeID", nodeID),
   261  					zap.Uint32("requestID", requestID),
   262  					zap.Int("responseLen", len(proofBytes)),
   263  					zap.Error(err),
   264  				)
   265  				return fmt.Errorf("%w: %w", errAppSendFailed, err)
   266  			}
   267  			return nil
   268  		}
   269  
   270  		// The proof was too large. Try to shrink it.
   271  		keyLimit = uint32(len(changeProof.KeyChanges)) / 2
   272  	}
   273  	return ErrMinProofSizeIsTooLarge
   274  }
   275  
   276  // Generates a range proof and sends it to [nodeID].
   277  // If [errAppSendFailed] is returned, this should be considered fatal.
   278  func (s *NetworkServer) HandleRangeProofRequest(
   279  	ctx context.Context,
   280  	nodeID ids.NodeID,
   281  	requestID uint32,
   282  	req *pb.SyncGetRangeProofRequest,
   283  ) error {
   284  	if err := validateRangeProofRequest(req); err != nil {
   285  		s.log.Debug(
   286  			"dropping invalid range proof request",
   287  			zap.Stringer("nodeID", nodeID),
   288  			zap.Uint32("requestID", requestID),
   289  			zap.Stringer("req", req),
   290  			zap.Error(err),
   291  		)
   292  		return nil // drop request
   293  	}
   294  
   295  	// override limits if they exceed caps
   296  	req.KeyLimit = min(req.KeyLimit, maxKeyValuesLimit)
   297  	req.BytesLimit = min(req.BytesLimit, maxByteSizeLimit)
   298  
   299  	proofBytes, err := getRangeProof(
   300  		ctx,
   301  		s.db,
   302  		req,
   303  		func(rangeProof *merkledb.RangeProof) ([]byte, error) {
   304  			return proto.Marshal(rangeProof.ToProto())
   305  		},
   306  	)
   307  	if err != nil {
   308  		return err
   309  	}
   310  	if err := s.appSender.SendAppResponse(ctx, nodeID, requestID, proofBytes); err != nil {
   311  		s.log.Fatal(
   312  			"failed to send app response",
   313  			zap.Stringer("nodeID", nodeID),
   314  			zap.Uint32("requestID", requestID),
   315  			zap.Int("responseLen", len(proofBytes)),
   316  			zap.Error(err),
   317  		)
   318  		return fmt.Errorf("%w: %w", errAppSendFailed, err)
   319  	}
   320  	return nil
   321  }
   322  
   323  // Get the range proof specified by [req].
   324  // If the generated proof is too large, the key limit is reduced
   325  // and the proof is regenerated. This process is repeated until
   326  // the proof is smaller than [req.BytesLimit].
   327  // When a sufficiently small proof is generated, returns it.
   328  // If no sufficiently small proof can be generated, returns [ErrMinProofSizeIsTooLarge].
   329  // TODO improve range proof generation so we don't need to iteratively
   330  // reduce the key limit.
   331  func getRangeProof(
   332  	ctx context.Context,
   333  	db DB,
   334  	req *pb.SyncGetRangeProofRequest,
   335  	marshalFunc func(*merkledb.RangeProof) ([]byte, error),
   336  ) ([]byte, error) {
   337  	root, err := ids.ToID(req.RootHash)
   338  	if err != nil {
   339  		return nil, err
   340  	}
   341  
   342  	keyLimit := int(req.KeyLimit)
   343  
   344  	for keyLimit > 0 {
   345  		rangeProof, err := db.GetRangeProofAtRoot(
   346  			ctx,
   347  			root,
   348  			maybeBytesToMaybe(req.StartKey),
   349  			maybeBytesToMaybe(req.EndKey),
   350  			keyLimit,
   351  		)
   352  		if err != nil {
   353  			if errors.Is(err, merkledb.ErrInsufficientHistory) {
   354  				return nil, nil // drop request
   355  			}
   356  			return nil, err
   357  		}
   358  
   359  		proofBytes, err := marshalFunc(rangeProof)
   360  		if err != nil {
   361  			return nil, err
   362  		}
   363  
   364  		if len(proofBytes) < int(req.BytesLimit) {
   365  			return proofBytes, nil
   366  		}
   367  
   368  		// The proof was too large. Try to shrink it.
   369  		keyLimit = len(rangeProof.KeyValues) / 2
   370  	}
   371  	return nil, ErrMinProofSizeIsTooLarge
   372  }
   373  
   374  // isTimeout returns true if err is a timeout from a context cancellation
   375  // or a context cancellation over grpc.
   376  func isTimeout(err error) bool {
   377  	// handle grpc wrapped DeadlineExceeded
   378  	if e, ok := status.FromError(err); ok {
   379  		if e.Code() == codes.DeadlineExceeded {
   380  			return true
   381  		}
   382  	}
   383  	// otherwise, check for context.DeadlineExceeded directly
   384  	return errors.Is(err, context.DeadlineExceeded)
   385  }
   386  
   387  // Returns nil iff [req] is well-formed.
   388  func validateChangeProofRequest(req *pb.SyncGetChangeProofRequest) error {
   389  	switch {
   390  	case req.BytesLimit == 0:
   391  		return errInvalidBytesLimit
   392  	case req.KeyLimit == 0:
   393  		return errInvalidKeyLimit
   394  	case len(req.StartRootHash) != hashing.HashLen:
   395  		return errInvalidStartRootHash
   396  	case len(req.EndRootHash) != hashing.HashLen:
   397  		return errInvalidEndRootHash
   398  	case bytes.Equal(req.EndRootHash, ids.Empty[:]):
   399  		return merkledb.ErrEmptyProof
   400  	case req.StartKey != nil && req.StartKey.IsNothing && len(req.StartKey.Value) > 0:
   401  		return errInvalidStartKey
   402  	case req.EndKey != nil && req.EndKey.IsNothing && len(req.EndKey.Value) > 0:
   403  		return errInvalidEndKey
   404  	case req.StartKey != nil && req.EndKey != nil && !req.StartKey.IsNothing &&
   405  		!req.EndKey.IsNothing && bytes.Compare(req.StartKey.Value, req.EndKey.Value) > 0:
   406  		return errInvalidBounds
   407  	default:
   408  		return nil
   409  	}
   410  }
   411  
   412  // Returns nil iff [req] is well-formed.
   413  func validateRangeProofRequest(req *pb.SyncGetRangeProofRequest) error {
   414  	switch {
   415  	case req.BytesLimit == 0:
   416  		return errInvalidBytesLimit
   417  	case req.KeyLimit == 0:
   418  		return errInvalidKeyLimit
   419  	case len(req.RootHash) != ids.IDLen:
   420  		return errInvalidRootHash
   421  	case bytes.Equal(req.RootHash, ids.Empty[:]):
   422  		return merkledb.ErrEmptyProof
   423  	case req.StartKey != nil && req.StartKey.IsNothing && len(req.StartKey.Value) > 0:
   424  		return errInvalidStartKey
   425  	case req.EndKey != nil && req.EndKey.IsNothing && len(req.EndKey.Value) > 0:
   426  		return errInvalidEndKey
   427  	case req.StartKey != nil && req.EndKey != nil && !req.StartKey.IsNothing &&
   428  		!req.EndKey.IsNothing && bytes.Compare(req.StartKey.Value, req.EndKey.Value) > 0:
   429  		return errInvalidBounds
   430  	default:
   431  		return nil
   432  	}
   433  }