github.com/status-im/status-go@v1.1.0/protocol/common/message_segmentation_test.go (about)

     1  package common
     2  
     3  import (
     4  	"fmt"
     5  	"math"
     6  	"testing"
     7  
     8  	"github.com/stretchr/testify/suite"
     9  	"go.uber.org/zap"
    10  	"golang.org/x/exp/slices"
    11  
    12  	"github.com/status-im/status-go/appdatabase"
    13  	"github.com/status-im/status-go/eth-node/crypto"
    14  	"github.com/status-im/status-go/eth-node/types"
    15  	"github.com/status-im/status-go/protocol/sqlite"
    16  	"github.com/status-im/status-go/protocol/v1"
    17  	"github.com/status-im/status-go/t/helpers"
    18  )
    19  
    20  func TestMessageSegmentationSuite(t *testing.T) {
    21  	suite.Run(t, new(MessageSegmentationSuite))
    22  }
    23  
    24  type MessageSegmentationSuite struct {
    25  	suite.Suite
    26  
    27  	sender      *MessageSender
    28  	testPayload []byte
    29  	logger      *zap.Logger
    30  }
    31  
    32  func (s *MessageSegmentationSuite) SetupSuite() {
    33  	s.testPayload = make([]byte, 1000)
    34  	for i := 0; i < 1000; i++ {
    35  		s.testPayload[i] = byte(i)
    36  	}
    37  }
    38  
    39  func (s *MessageSegmentationSuite) SetupTest() {
    40  	identity, err := crypto.GenerateKey()
    41  	s.Require().NoError(err)
    42  
    43  	database, err := helpers.SetupTestMemorySQLDB(appdatabase.DbInitializer{})
    44  	s.Require().NoError(err)
    45  	err = sqlite.Migrate(database)
    46  	s.Require().NoError(err)
    47  
    48  	s.logger, err = zap.NewDevelopment()
    49  	s.Require().NoError(err)
    50  
    51  	s.sender, err = NewMessageSender(
    52  		identity,
    53  		database,
    54  		nil,
    55  		nil,
    56  		s.logger,
    57  		FeatureFlags{},
    58  	)
    59  	s.Require().NoError(err)
    60  }
    61  
    62  func (s *MessageSegmentationSuite) SetupSubTest() {
    63  	s.SetupTest()
    64  }
    65  
    66  func (s *MessageSegmentationSuite) TestHandleSegmentationLayer() {
    67  	testCases := []struct {
    68  		name                             string
    69  		segmentsCount                    int
    70  		expectedParitySegmentsCount      int
    71  		retrievedSegments                []int
    72  		retrievedParitySegments          []int
    73  		segmentationLayerV1ShouldSucceed bool
    74  		segmentationLayerV2ShouldSucceed bool
    75  	}{
    76  		{
    77  			name:                             "all segments retrieved",
    78  			segmentsCount:                    2,
    79  			expectedParitySegmentsCount:      0,
    80  			retrievedSegments:                []int{0, 1},
    81  			retrievedParitySegments:          []int{},
    82  			segmentationLayerV1ShouldSucceed: true,
    83  			segmentationLayerV2ShouldSucceed: true,
    84  		},
    85  		{
    86  			name:                             "all segments retrieved out of order",
    87  			segmentsCount:                    2,
    88  			expectedParitySegmentsCount:      0,
    89  			retrievedSegments:                []int{1, 0},
    90  			retrievedParitySegments:          []int{},
    91  			segmentationLayerV1ShouldSucceed: true,
    92  			segmentationLayerV2ShouldSucceed: true,
    93  		},
    94  		{
    95  			name:                             "all segments&parity retrieved",
    96  			segmentsCount:                    8,
    97  			expectedParitySegmentsCount:      1,
    98  			retrievedSegments:                []int{0, 1, 2, 3, 4, 5, 6, 7, 8},
    99  			retrievedParitySegments:          []int{8},
   100  			segmentationLayerV1ShouldSucceed: true,
   101  			segmentationLayerV2ShouldSucceed: true,
   102  		},
   103  		{
   104  			name:                             "all segments&parity retrieved out of order",
   105  			segmentsCount:                    8,
   106  			expectedParitySegmentsCount:      1,
   107  			retrievedSegments:                []int{8, 0, 7, 1, 6, 2, 5, 3, 4},
   108  			retrievedParitySegments:          []int{8},
   109  			segmentationLayerV1ShouldSucceed: true,
   110  			segmentationLayerV2ShouldSucceed: true,
   111  		},
   112  		{
   113  			name:                             "no segments retrieved",
   114  			segmentsCount:                    2,
   115  			expectedParitySegmentsCount:      0,
   116  			retrievedSegments:                []int{},
   117  			retrievedParitySegments:          []int{},
   118  			segmentationLayerV1ShouldSucceed: false,
   119  			segmentationLayerV2ShouldSucceed: false,
   120  		},
   121  		{
   122  			name:                             "not all needed segments&parity retrieved",
   123  			segmentsCount:                    8,
   124  			expectedParitySegmentsCount:      1,
   125  			retrievedSegments:                []int{1, 2, 8},
   126  			retrievedParitySegments:          []int{8},
   127  			segmentationLayerV1ShouldSucceed: false,
   128  			segmentationLayerV2ShouldSucceed: false,
   129  		},
   130  		{
   131  			name:                             "segments&parity retrieved",
   132  			segmentsCount:                    8,
   133  			expectedParitySegmentsCount:      1,
   134  			retrievedSegments:                []int{1, 2, 3, 4, 5, 6, 7, 8},
   135  			retrievedParitySegments:          []int{8},
   136  			segmentationLayerV1ShouldSucceed: false,
   137  			segmentationLayerV2ShouldSucceed: true, // succeed even though one segment is missing, thank you reedsolomon
   138  		},
   139  		{
   140  			name:                             "segments&parity retrieved out of order",
   141  			segmentsCount:                    16,
   142  			expectedParitySegmentsCount:      2,
   143  			retrievedSegments:                []int{17, 0, 16, 1, 15, 2, 14, 3, 13, 4, 12, 5, 11, 6, 10, 7},
   144  			retrievedParitySegments:          []int{16, 17},
   145  			segmentationLayerV1ShouldSucceed: false,
   146  			segmentationLayerV2ShouldSucceed: true, // succeed even though two segments are missing, thank you reedsolomon
   147  		},
   148  	}
   149  
   150  	for _, version := range []string{"V1", "V2"} {
   151  		for _, tc := range testCases {
   152  			s.Run(fmt.Sprintf("%s %s", version, tc.name), func() {
   153  				segmentedMessages, err := segmentMessage(&types.NewMessage{Payload: s.testPayload}, int(math.Ceil(float64(len(s.testPayload))/float64(tc.segmentsCount))))
   154  				s.Require().NoError(err)
   155  				s.Require().Len(segmentedMessages, tc.segmentsCount+tc.expectedParitySegmentsCount)
   156  
   157  				message := &protocol.StatusMessage{TransportLayer: protocol.TransportLayer{
   158  					SigPubKey: &s.sender.identity.PublicKey,
   159  				}}
   160  
   161  				messageRecreated := false
   162  				handledSegments := []int{}
   163  
   164  				for i, segmentIndex := range tc.retrievedSegments {
   165  					s.T().Log("i=", i, "segmentIndex=", segmentIndex)
   166  
   167  					message.TransportLayer.Payload = segmentedMessages[segmentIndex].Payload
   168  
   169  					if version == "V1" {
   170  						err = s.sender.handleSegmentationLayerV1(message)
   171  						// V1 is unable to handle parity segment
   172  						if slices.Contains(tc.retrievedParitySegments, segmentIndex) {
   173  							if len(handledSegments) >= tc.segmentsCount {
   174  								s.Require().ErrorIs(err, ErrMessageSegmentsAlreadyCompleted)
   175  							} else {
   176  								s.Require().ErrorIs(err, ErrMessageSegmentsInvalidCount)
   177  							}
   178  							continue
   179  						}
   180  					} else {
   181  						err = s.sender.handleSegmentationLayerV2(message)
   182  					}
   183  
   184  					handledSegments = append(handledSegments, segmentIndex)
   185  
   186  					if len(handledSegments) < tc.segmentsCount {
   187  						s.Require().ErrorIs(err, ErrMessageSegmentsIncomplete)
   188  					} else if len(handledSegments) == tc.segmentsCount {
   189  						s.Require().NoError(err)
   190  						s.Require().ElementsMatch(s.testPayload, message.TransportLayer.Payload)
   191  						messageRecreated = true
   192  					} else {
   193  						s.Require().ErrorIs(err, ErrMessageSegmentsAlreadyCompleted)
   194  					}
   195  				}
   196  
   197  				if version == "V1" {
   198  					s.Require().Equal(tc.segmentationLayerV1ShouldSucceed, messageRecreated)
   199  				} else {
   200  					s.Require().Equal(tc.segmentationLayerV2ShouldSucceed, messageRecreated)
   201  				}
   202  			})
   203  		}
   204  	}
   205  }