github.com/status-im/status-go@v1.1.0/protocol/common/message_segmentation.go (about)

     1  package common
     2  
     3  import (
     4  	"bytes"
     5  	"math"
     6  	"time"
     7  
     8  	"github.com/golang/protobuf/proto"
     9  	"github.com/jinzhu/copier"
    10  	"github.com/klauspost/reedsolomon"
    11  	"github.com/pkg/errors"
    12  	"go.uber.org/zap"
    13  
    14  	"github.com/status-im/status-go/eth-node/crypto"
    15  	"github.com/status-im/status-go/eth-node/types"
    16  	"github.com/status-im/status-go/protocol/protobuf"
    17  	v1protocol "github.com/status-im/status-go/protocol/v1"
    18  )
    19  
    20  var ErrMessageSegmentsIncomplete = errors.New("message segments incomplete")
    21  var ErrMessageSegmentsAlreadyCompleted = errors.New("message segments already completed")
    22  var ErrMessageSegmentsInvalidCount = errors.New("invalid segments count")
    23  var ErrMessageSegmentsHashMismatch = errors.New("hash of entire payload does not match")
    24  var ErrMessageSegmentsInvalidParity = errors.New("invalid parity segments")
    25  
    26  const (
    27  	segmentsParityRate          = 0.125
    28  	segmentsReedsolomonMaxCount = 256
    29  )
    30  
    31  type SegmentMessage struct {
    32  	*protobuf.SegmentMessage
    33  }
    34  
    35  func (s *SegmentMessage) IsValid() bool {
    36  	return s.SegmentsCount >= 2 || s.ParitySegmentsCount > 0
    37  }
    38  
    39  func (s *SegmentMessage) IsParityMessage() bool {
    40  	return s.SegmentsCount == 0 && s.ParitySegmentsCount > 0
    41  }
    42  
    43  func (s *MessageSender) segmentMessage(newMessage *types.NewMessage) ([]*types.NewMessage, error) {
    44  	// We set the max message size to 3/4 of the allowed message size, to leave
    45  	// room for segment message metadata.
    46  	newMessages, err := segmentMessage(newMessage, int(s.transport.MaxMessageSize()/4*3))
    47  	s.logger.Debug("message segmented", zap.Int("segments", len(newMessages)))
    48  	return newMessages, err
    49  }
    50  
    51  func replicateMessageWithNewPayload(message *types.NewMessage, payload []byte) (*types.NewMessage, error) {
    52  	copy := &types.NewMessage{}
    53  	err := copier.Copy(copy, message)
    54  	if err != nil {
    55  		return nil, err
    56  	}
    57  
    58  	copy.Payload = payload
    59  	copy.PowTarget = calculatePoW(payload)
    60  	return copy, nil
    61  }
    62  
    63  // Segments message into smaller chunks if the size exceeds segmentSize.
    64  func segmentMessage(newMessage *types.NewMessage, segmentSize int) ([]*types.NewMessage, error) {
    65  	if len(newMessage.Payload) <= segmentSize {
    66  		return []*types.NewMessage{newMessage}, nil
    67  	}
    68  
    69  	entireMessageHash := crypto.Keccak256(newMessage.Payload)
    70  	entirePayloadSize := len(newMessage.Payload)
    71  
    72  	segmentsCount := int(math.Ceil(float64(entirePayloadSize) / float64(segmentSize)))
    73  	paritySegmentsCount := int(math.Floor(float64(segmentsCount) * segmentsParityRate))
    74  
    75  	segmentPayloads := make([][]byte, segmentsCount+paritySegmentsCount)
    76  	segmentMessages := make([]*types.NewMessage, segmentsCount)
    77  
    78  	for start, index := 0, 0; start < entirePayloadSize; start += segmentSize {
    79  		end := start + segmentSize
    80  		if end > entirePayloadSize {
    81  			end = entirePayloadSize
    82  		}
    83  
    84  		segmentPayload := newMessage.Payload[start:end]
    85  		segmentWithMetadata := &protobuf.SegmentMessage{
    86  			EntireMessageHash: entireMessageHash,
    87  			Index:             uint32(index),
    88  			SegmentsCount:     uint32(segmentsCount),
    89  			Payload:           segmentPayload,
    90  		}
    91  		marshaledSegmentWithMetadata, err := proto.Marshal(segmentWithMetadata)
    92  		if err != nil {
    93  			return nil, err
    94  		}
    95  		segmentMessage, err := replicateMessageWithNewPayload(newMessage, marshaledSegmentWithMetadata)
    96  		if err != nil {
    97  			return nil, err
    98  		}
    99  
   100  		segmentPayloads[index] = segmentPayload
   101  		segmentMessages[index] = segmentMessage
   102  		index++
   103  	}
   104  
   105  	// Skip reedsolomon if the combined total of data and parity segments exceeds the predefined limit of segmentsReedsolomonMaxCount.
   106  	// Exceeding this limit necessitates shard sizes to be multiples of 64, which are incompatible with clients that do not support forward error correction.
   107  	if paritySegmentsCount == 0 || segmentsCount+paritySegmentsCount > segmentsReedsolomonMaxCount {
   108  		return segmentMessages, nil
   109  	}
   110  
   111  	enc, err := reedsolomon.New(segmentsCount, paritySegmentsCount)
   112  	if err != nil {
   113  		return nil, err
   114  	}
   115  
   116  	// Align the size of the last segment payload.
   117  	lastSegmentPayload := segmentPayloads[segmentsCount-1]
   118  	segmentPayloads[segmentsCount-1] = make([]byte, segmentSize)
   119  	copy(segmentPayloads[segmentsCount-1], lastSegmentPayload)
   120  
   121  	// Make space for parity data.
   122  	for i := segmentsCount; i < segmentsCount+paritySegmentsCount; i++ {
   123  		segmentPayloads[i] = make([]byte, segmentSize)
   124  	}
   125  
   126  	err = enc.Encode(segmentPayloads)
   127  	if err != nil {
   128  		return nil, err
   129  	}
   130  
   131  	// Create parity messages.
   132  	for i, index := segmentsCount, 0; i < segmentsCount+paritySegmentsCount; i++ {
   133  		segmentWithMetadata := &protobuf.SegmentMessage{
   134  			EntireMessageHash:   entireMessageHash,
   135  			SegmentsCount:       0, // indicates parity message
   136  			ParitySegmentIndex:  uint32(index),
   137  			ParitySegmentsCount: uint32(paritySegmentsCount),
   138  			Payload:             segmentPayloads[i],
   139  		}
   140  		marshaledSegmentWithMetadata, err := proto.Marshal(segmentWithMetadata)
   141  		if err != nil {
   142  			return nil, err
   143  		}
   144  		segmentMessage, err := replicateMessageWithNewPayload(newMessage, marshaledSegmentWithMetadata)
   145  		if err != nil {
   146  			return nil, err
   147  		}
   148  
   149  		segmentMessages = append(segmentMessages, segmentMessage)
   150  		index++
   151  	}
   152  
   153  	return segmentMessages, nil
   154  }
   155  
   156  // SegmentationLayerV1 reconstructs the message only when all segments have been successfully retrieved.
   157  // It lacks the capability to perform forward error correction.
   158  // Kept to test forward compatibility.
   159  func (s *MessageSender) handleSegmentationLayerV1(message *v1protocol.StatusMessage) error {
   160  	logger := s.logger.With(zap.String("site", "handleSegmentationLayerV1")).With(zap.String("hash", types.HexBytes(message.TransportLayer.Hash).String()))
   161  
   162  	segmentMessage := &SegmentMessage{
   163  		SegmentMessage: &protobuf.SegmentMessage{},
   164  	}
   165  	err := proto.Unmarshal(message.TransportLayer.Payload, segmentMessage.SegmentMessage)
   166  	if err != nil {
   167  		return errors.Wrap(err, "failed to unmarshal SegmentMessage")
   168  	}
   169  
   170  	logger.Debug("handling message segment", zap.String("EntireMessageHash", types.HexBytes(segmentMessage.EntireMessageHash).String()),
   171  		zap.Uint32("Index", segmentMessage.Index), zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount))
   172  
   173  	alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash)
   174  	if err != nil {
   175  		return err
   176  	}
   177  	if alreadyCompleted {
   178  		return ErrMessageSegmentsAlreadyCompleted
   179  	}
   180  
   181  	if segmentMessage.SegmentsCount < 2 {
   182  		return ErrMessageSegmentsInvalidCount
   183  	}
   184  
   185  	err = s.persistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix())
   186  	if err != nil {
   187  		return err
   188  	}
   189  
   190  	segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey)
   191  	if err != nil {
   192  		return err
   193  	}
   194  
   195  	if len(segments) != int(segmentMessage.SegmentsCount) {
   196  		return ErrMessageSegmentsIncomplete
   197  	}
   198  
   199  	// Combine payload
   200  	var entirePayload bytes.Buffer
   201  	for _, segment := range segments {
   202  		_, err := entirePayload.Write(segment.Payload)
   203  		if err != nil {
   204  			return errors.Wrap(err, "failed to write segment payload")
   205  		}
   206  	}
   207  
   208  	// Sanity check
   209  	entirePayloadHash := crypto.Keccak256(entirePayload.Bytes())
   210  	if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) {
   211  		return ErrMessageSegmentsHashMismatch
   212  	}
   213  
   214  	err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix())
   215  	if err != nil {
   216  		return err
   217  	}
   218  
   219  	message.TransportLayer.Payload = entirePayload.Bytes()
   220  
   221  	return nil
   222  }
   223  
   224  // SegmentationLayerV2 is capable of reconstructing the message from both complete and partial sets of data segments.
   225  // It has capability to perform forward error correction.
   226  func (s *MessageSender) handleSegmentationLayerV2(message *v1protocol.StatusMessage) error {
   227  	logger := s.logger.With(zap.String("site", "handleSegmentationLayerV2")).With(zap.String("hash", types.HexBytes(message.TransportLayer.Hash).String()))
   228  
   229  	segmentMessage := &SegmentMessage{
   230  		SegmentMessage: &protobuf.SegmentMessage{},
   231  	}
   232  	err := proto.Unmarshal(message.TransportLayer.Payload, segmentMessage.SegmentMessage)
   233  	if err != nil {
   234  		return errors.Wrap(err, "failed to unmarshal SegmentMessage")
   235  	}
   236  
   237  	logger.Debug("handling message segment",
   238  		zap.String("EntireMessageHash", types.HexBytes(segmentMessage.EntireMessageHash).String()),
   239  		zap.Uint32("Index", segmentMessage.Index),
   240  		zap.Uint32("SegmentsCount", segmentMessage.SegmentsCount),
   241  		zap.Uint32("ParitySegmentIndex", segmentMessage.ParitySegmentIndex),
   242  		zap.Uint32("ParitySegmentsCount", segmentMessage.ParitySegmentsCount))
   243  
   244  	alreadyCompleted, err := s.persistence.IsMessageAlreadyCompleted(segmentMessage.EntireMessageHash)
   245  	if err != nil {
   246  		return err
   247  	}
   248  	if alreadyCompleted {
   249  		return ErrMessageSegmentsAlreadyCompleted
   250  	}
   251  
   252  	if !segmentMessage.IsValid() {
   253  		return ErrMessageSegmentsInvalidCount
   254  	}
   255  
   256  	err = s.persistence.SaveMessageSegment(segmentMessage, message.TransportLayer.SigPubKey, time.Now().Unix())
   257  	if err != nil {
   258  		return err
   259  	}
   260  
   261  	segments, err := s.persistence.GetMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey)
   262  	if err != nil {
   263  		return err
   264  	}
   265  
   266  	if len(segments) == 0 {
   267  		return errors.New("unexpected state: no segments found after save operation") // This should theoretically never occur.
   268  	}
   269  
   270  	firstSegmentMessage := segments[0]
   271  	lastSegmentMessage := segments[len(segments)-1]
   272  
   273  	// First segment message must not be a parity message.
   274  	if firstSegmentMessage.IsParityMessage() || len(segments) != int(firstSegmentMessage.SegmentsCount) {
   275  		return ErrMessageSegmentsIncomplete
   276  	}
   277  
   278  	payloads := make([][]byte, firstSegmentMessage.SegmentsCount+lastSegmentMessage.ParitySegmentsCount)
   279  	payloadSize := len(firstSegmentMessage.Payload)
   280  
   281  	restoreUsingParityData := lastSegmentMessage.IsParityMessage()
   282  	if !restoreUsingParityData {
   283  		for i, segment := range segments {
   284  			payloads[i] = segment.Payload
   285  		}
   286  	} else {
   287  		enc, err := reedsolomon.New(int(firstSegmentMessage.SegmentsCount), int(lastSegmentMessage.ParitySegmentsCount))
   288  		if err != nil {
   289  			return err
   290  		}
   291  
   292  		var lastNonParitySegmentPayload []byte
   293  		for _, segment := range segments {
   294  			if !segment.IsParityMessage() {
   295  				if segment.Index == firstSegmentMessage.SegmentsCount-1 {
   296  					// Ensure last segment is aligned to payload size, as it is required by reedsolomon.
   297  					payloads[segment.Index] = make([]byte, payloadSize)
   298  					copy(payloads[segment.Index], segment.Payload)
   299  					lastNonParitySegmentPayload = segment.Payload
   300  				} else {
   301  					payloads[segment.Index] = segment.Payload
   302  				}
   303  			} else {
   304  				payloads[firstSegmentMessage.SegmentsCount+segment.ParitySegmentIndex] = segment.Payload
   305  			}
   306  		}
   307  
   308  		err = enc.Reconstruct(payloads)
   309  		if err != nil {
   310  			return err
   311  		}
   312  
   313  		ok, err := enc.Verify(payloads)
   314  		if err != nil {
   315  			return err
   316  		}
   317  		if !ok {
   318  			return ErrMessageSegmentsInvalidParity
   319  		}
   320  
   321  		if lastNonParitySegmentPayload != nil {
   322  			payloads[firstSegmentMessage.SegmentsCount-1] = lastNonParitySegmentPayload // Bring back last segment with original length.
   323  		}
   324  	}
   325  
   326  	// Combine payload.
   327  	var entirePayload bytes.Buffer
   328  	for i := 0; i < int(firstSegmentMessage.SegmentsCount); i++ {
   329  		_, err := entirePayload.Write(payloads[i])
   330  		if err != nil {
   331  			return errors.Wrap(err, "failed to write segment payload")
   332  		}
   333  	}
   334  
   335  	// Sanity check.
   336  	entirePayloadHash := crypto.Keccak256(entirePayload.Bytes())
   337  	if !bytes.Equal(entirePayloadHash, segmentMessage.EntireMessageHash) {
   338  		return ErrMessageSegmentsHashMismatch
   339  	}
   340  
   341  	err = s.persistence.CompleteMessageSegments(segmentMessage.EntireMessageHash, message.TransportLayer.SigPubKey, time.Now().Unix())
   342  	if err != nil {
   343  		return err
   344  	}
   345  
   346  	message.TransportLayer.Payload = entirePayload.Bytes()
   347  
   348  	return nil
   349  }
   350  
   351  func (s *MessageSender) CleanupSegments() error {
   352  	monthAgo := time.Now().AddDate(0, -1, 0).Unix()
   353  
   354  	err := s.persistence.RemoveMessageSegmentsOlderThan(monthAgo)
   355  	if err != nil {
   356  		return err
   357  	}
   358  
   359  	err = s.persistence.RemoveMessageSegmentsCompletedOlderThan(monthAgo)
   360  	if err != nil {
   361  		return err
   362  	}
   363  
   364  	return nil
   365  }