github.com/MetalBlockchain/metalgo@v1.11.9/message/inbound_msg_builder_test.go (about)

     1  // Copyright (C) 2019-2024, Ava Labs, Inc. All rights reserved.
     2  // See the file LICENSE for licensing terms.
     3  
     4  package message
     5  
     6  import (
     7  	"testing"
     8  	"time"
     9  
    10  	"github.com/prometheus/client_golang/prometheus"
    11  	"github.com/stretchr/testify/require"
    12  
    13  	"github.com/MetalBlockchain/metalgo/ids"
    14  	"github.com/MetalBlockchain/metalgo/proto/pb/p2p"
    15  	"github.com/MetalBlockchain/metalgo/utils/compression"
    16  	"github.com/MetalBlockchain/metalgo/utils/logging"
    17  	"github.com/MetalBlockchain/metalgo/utils/timer/mockable"
    18  )
    19  
    20  func Test_newMsgBuilder(t *testing.T) {
    21  	t.Parallel()
    22  	require := require.New(t)
    23  
    24  	mb, err := newMsgBuilder(
    25  		logging.NoLog{},
    26  		prometheus.NewRegistry(),
    27  		10*time.Second,
    28  	)
    29  	require.NoError(err)
    30  	require.NotNil(mb)
    31  }
    32  
    33  func TestInboundMsgBuilder(t *testing.T) {
    34  	var (
    35  		chainID                    = ids.GenerateTestID()
    36  		requestID           uint32 = 12345
    37  		deadline                   = time.Hour
    38  		nodeID                     = ids.GenerateTestNodeID()
    39  		summary                    = []byte{9, 8, 7}
    40  		appBytes                   = []byte{1, 3, 3, 7}
    41  		container                  = []byte{1, 2, 3, 4, 5, 6, 7, 8, 9}
    42  		containerIDs               = []ids.ID{ids.GenerateTestID(), ids.GenerateTestID()}
    43  		requestedHeight     uint64 = 999
    44  		acceptedContainerID        = ids.GenerateTestID()
    45  		summaryIDs                 = []ids.ID{ids.GenerateTestID(), ids.GenerateTestID()}
    46  		heights                    = []uint64{1000, 2000}
    47  	)
    48  
    49  	t.Run(
    50  		"InboundGetStateSummaryFrontier",
    51  		func(t *testing.T) {
    52  			require := require.New(t)
    53  
    54  			start := time.Now()
    55  			msg := InboundGetStateSummaryFrontier(
    56  				chainID,
    57  				requestID,
    58  				deadline,
    59  				nodeID,
    60  			)
    61  			end := time.Now()
    62  
    63  			require.Equal(GetStateSummaryFrontierOp, msg.Op())
    64  			require.Equal(nodeID, msg.NodeID())
    65  			require.False(msg.Expiration().Before(start.Add(deadline)))
    66  			require.False(end.Add(deadline).Before(msg.Expiration()))
    67  			require.IsType(&p2p.GetStateSummaryFrontier{}, msg.Message())
    68  			innerMsg := msg.Message().(*p2p.GetStateSummaryFrontier)
    69  			require.Equal(chainID[:], innerMsg.ChainId)
    70  			require.Equal(requestID, innerMsg.RequestId)
    71  		},
    72  	)
    73  
    74  	t.Run(
    75  		"InboundStateSummaryFrontier",
    76  		func(t *testing.T) {
    77  			require := require.New(t)
    78  
    79  			msg := InboundStateSummaryFrontier(
    80  				chainID,
    81  				requestID,
    82  				summary,
    83  				nodeID,
    84  			)
    85  
    86  			require.Equal(StateSummaryFrontierOp, msg.Op())
    87  			require.Equal(nodeID, msg.NodeID())
    88  			require.Equal(mockable.MaxTime, msg.Expiration())
    89  			require.IsType(&p2p.StateSummaryFrontier{}, msg.Message())
    90  			innerMsg := msg.Message().(*p2p.StateSummaryFrontier)
    91  			require.Equal(chainID[:], innerMsg.ChainId)
    92  			require.Equal(requestID, innerMsg.RequestId)
    93  			require.Equal(summary, innerMsg.Summary)
    94  		},
    95  	)
    96  
    97  	t.Run(
    98  		"InboundGetAcceptedStateSummary",
    99  		func(t *testing.T) {
   100  			require := require.New(t)
   101  
   102  			start := time.Now()
   103  			msg := InboundGetAcceptedStateSummary(
   104  				chainID,
   105  				requestID,
   106  				heights,
   107  				deadline,
   108  				nodeID,
   109  			)
   110  			end := time.Now()
   111  
   112  			require.Equal(GetAcceptedStateSummaryOp, msg.Op())
   113  			require.Equal(nodeID, msg.NodeID())
   114  			require.False(msg.Expiration().Before(start.Add(deadline)))
   115  			require.False(end.Add(deadline).Before(msg.Expiration()))
   116  			require.IsType(&p2p.GetAcceptedStateSummary{}, msg.Message())
   117  			innerMsg := msg.Message().(*p2p.GetAcceptedStateSummary)
   118  			require.Equal(chainID[:], innerMsg.ChainId)
   119  			require.Equal(requestID, innerMsg.RequestId)
   120  			require.Equal(heights, innerMsg.Heights)
   121  		},
   122  	)
   123  
   124  	t.Run(
   125  		"InboundAcceptedStateSummary",
   126  		func(t *testing.T) {
   127  			require := require.New(t)
   128  
   129  			msg := InboundAcceptedStateSummary(
   130  				chainID,
   131  				requestID,
   132  				summaryIDs,
   133  				nodeID,
   134  			)
   135  
   136  			require.Equal(AcceptedStateSummaryOp, msg.Op())
   137  			require.Equal(nodeID, msg.NodeID())
   138  			require.Equal(mockable.MaxTime, msg.Expiration())
   139  			require.IsType(&p2p.AcceptedStateSummary{}, msg.Message())
   140  			innerMsg := msg.Message().(*p2p.AcceptedStateSummary)
   141  			require.Equal(chainID[:], innerMsg.ChainId)
   142  			require.Equal(requestID, innerMsg.RequestId)
   143  			summaryIDsBytes := make([][]byte, len(summaryIDs))
   144  			for i, id := range summaryIDs {
   145  				id := id
   146  				summaryIDsBytes[i] = id[:]
   147  			}
   148  			require.Equal(summaryIDsBytes, innerMsg.SummaryIds)
   149  		},
   150  	)
   151  
   152  	t.Run(
   153  		"InboundGetAcceptedFrontier",
   154  		func(t *testing.T) {
   155  			require := require.New(t)
   156  
   157  			start := time.Now()
   158  			msg := InboundGetAcceptedFrontier(
   159  				chainID,
   160  				requestID,
   161  				deadline,
   162  				nodeID,
   163  			)
   164  			end := time.Now()
   165  
   166  			require.Equal(GetAcceptedFrontierOp, msg.Op())
   167  			require.Equal(nodeID, msg.NodeID())
   168  			require.False(msg.Expiration().Before(start.Add(deadline)))
   169  			require.False(end.Add(deadline).Before(msg.Expiration()))
   170  			require.IsType(&p2p.GetAcceptedFrontier{}, msg.Message())
   171  			innerMsg := msg.Message().(*p2p.GetAcceptedFrontier)
   172  			require.Equal(chainID[:], innerMsg.ChainId)
   173  			require.Equal(requestID, innerMsg.RequestId)
   174  		},
   175  	)
   176  
   177  	t.Run(
   178  		"InboundAcceptedFrontier",
   179  		func(t *testing.T) {
   180  			require := require.New(t)
   181  
   182  			msg := InboundAcceptedFrontier(
   183  				chainID,
   184  				requestID,
   185  				containerIDs[0],
   186  				nodeID,
   187  			)
   188  
   189  			require.Equal(AcceptedFrontierOp, msg.Op())
   190  			require.Equal(nodeID, msg.NodeID())
   191  			require.Equal(mockable.MaxTime, msg.Expiration())
   192  			require.IsType(&p2p.AcceptedFrontier{}, msg.Message())
   193  			innerMsg := msg.Message().(*p2p.AcceptedFrontier)
   194  			require.Equal(chainID[:], innerMsg.ChainId)
   195  			require.Equal(requestID, innerMsg.RequestId)
   196  			require.Equal(containerIDs[0][:], innerMsg.ContainerId)
   197  		},
   198  	)
   199  
   200  	t.Run(
   201  		"InboundGetAccepted",
   202  		func(t *testing.T) {
   203  			require := require.New(t)
   204  
   205  			start := time.Now()
   206  			msg := InboundGetAccepted(
   207  				chainID,
   208  				requestID,
   209  				deadline,
   210  				containerIDs,
   211  				nodeID,
   212  			)
   213  			end := time.Now()
   214  
   215  			require.Equal(GetAcceptedOp, msg.Op())
   216  			require.Equal(nodeID, msg.NodeID())
   217  			require.False(msg.Expiration().Before(start.Add(deadline)))
   218  			require.False(end.Add(deadline).Before(msg.Expiration()))
   219  			require.IsType(&p2p.GetAccepted{}, msg.Message())
   220  			innerMsg := msg.Message().(*p2p.GetAccepted)
   221  			require.Equal(chainID[:], innerMsg.ChainId)
   222  			require.Equal(requestID, innerMsg.RequestId)
   223  		},
   224  	)
   225  
   226  	t.Run(
   227  		"InboundAccepted",
   228  		func(t *testing.T) {
   229  			require := require.New(t)
   230  
   231  			msg := InboundAccepted(
   232  				chainID,
   233  				requestID,
   234  				containerIDs,
   235  				nodeID,
   236  			)
   237  
   238  			require.Equal(AcceptedOp, msg.Op())
   239  			require.Equal(nodeID, msg.NodeID())
   240  			require.Equal(mockable.MaxTime, msg.Expiration())
   241  			require.IsType(&p2p.Accepted{}, msg.Message())
   242  			innerMsg := msg.Message().(*p2p.Accepted)
   243  			require.Equal(chainID[:], innerMsg.ChainId)
   244  			require.Equal(requestID, innerMsg.RequestId)
   245  			containerIDsBytes := make([][]byte, len(containerIDs))
   246  			for i, id := range containerIDs {
   247  				id := id
   248  				containerIDsBytes[i] = id[:]
   249  			}
   250  			require.Equal(containerIDsBytes, innerMsg.ContainerIds)
   251  		},
   252  	)
   253  
   254  	t.Run(
   255  		"InboundPushQuery",
   256  		func(t *testing.T) {
   257  			require := require.New(t)
   258  
   259  			start := time.Now()
   260  			msg := InboundPushQuery(
   261  				chainID,
   262  				requestID,
   263  				deadline,
   264  				container,
   265  				requestedHeight,
   266  				nodeID,
   267  			)
   268  			end := time.Now()
   269  
   270  			require.Equal(PushQueryOp, msg.Op())
   271  			require.Equal(nodeID, msg.NodeID())
   272  			require.False(msg.Expiration().Before(start.Add(deadline)))
   273  			require.False(end.Add(deadline).Before(msg.Expiration()))
   274  			require.IsType(&p2p.PushQuery{}, msg.Message())
   275  			innerMsg := msg.Message().(*p2p.PushQuery)
   276  			require.Equal(chainID[:], innerMsg.ChainId)
   277  			require.Equal(requestID, innerMsg.RequestId)
   278  			require.Equal(container, innerMsg.Container)
   279  			require.Equal(requestedHeight, innerMsg.RequestedHeight)
   280  		},
   281  	)
   282  
   283  	t.Run(
   284  		"InboundPullQuery",
   285  		func(t *testing.T) {
   286  			require := require.New(t)
   287  
   288  			start := time.Now()
   289  			msg := InboundPullQuery(
   290  				chainID,
   291  				requestID,
   292  				deadline,
   293  				containerIDs[0],
   294  				requestedHeight,
   295  				nodeID,
   296  			)
   297  			end := time.Now()
   298  
   299  			require.Equal(PullQueryOp, msg.Op())
   300  			require.Equal(nodeID, msg.NodeID())
   301  			require.False(msg.Expiration().Before(start.Add(deadline)))
   302  			require.False(end.Add(deadline).Before(msg.Expiration()))
   303  			require.IsType(&p2p.PullQuery{}, msg.Message())
   304  			innerMsg := msg.Message().(*p2p.PullQuery)
   305  			require.Equal(chainID[:], innerMsg.ChainId)
   306  			require.Equal(requestID, innerMsg.RequestId)
   307  			require.Equal(containerIDs[0][:], innerMsg.ContainerId)
   308  			require.Equal(requestedHeight, innerMsg.RequestedHeight)
   309  		},
   310  	)
   311  
   312  	t.Run(
   313  		"InboundChits",
   314  		func(t *testing.T) {
   315  			require := require.New(t)
   316  
   317  			msg := InboundChits(
   318  				chainID,
   319  				requestID,
   320  				containerIDs[0],
   321  				containerIDs[1],
   322  				acceptedContainerID,
   323  				nodeID,
   324  			)
   325  
   326  			require.Equal(ChitsOp, msg.Op())
   327  			require.Equal(nodeID, msg.NodeID())
   328  			require.Equal(mockable.MaxTime, msg.Expiration())
   329  			require.IsType(&p2p.Chits{}, msg.Message())
   330  			innerMsg := msg.Message().(*p2p.Chits)
   331  			require.Equal(chainID[:], innerMsg.ChainId)
   332  			require.Equal(requestID, innerMsg.RequestId)
   333  			require.Equal(containerIDs[0][:], innerMsg.PreferredId)
   334  			require.Equal(containerIDs[1][:], innerMsg.PreferredIdAtHeight)
   335  			require.Equal(acceptedContainerID[:], innerMsg.AcceptedId)
   336  		},
   337  	)
   338  
   339  	t.Run(
   340  		"InboundAppRequest",
   341  		func(t *testing.T) {
   342  			require := require.New(t)
   343  
   344  			start := time.Now()
   345  			msg := InboundAppRequest(
   346  				chainID,
   347  				requestID,
   348  				deadline,
   349  				appBytes,
   350  				nodeID,
   351  			)
   352  			end := time.Now()
   353  
   354  			require.Equal(AppRequestOp, msg.Op())
   355  			require.Equal(nodeID, msg.NodeID())
   356  			require.False(msg.Expiration().Before(start.Add(deadline)))
   357  			require.False(end.Add(deadline).Before(msg.Expiration()))
   358  			require.IsType(&p2p.AppRequest{}, msg.Message())
   359  			innerMsg := msg.Message().(*p2p.AppRequest)
   360  			require.Equal(chainID[:], innerMsg.ChainId)
   361  			require.Equal(requestID, innerMsg.RequestId)
   362  			require.Equal(appBytes, innerMsg.AppBytes)
   363  		},
   364  	)
   365  
   366  	t.Run(
   367  		"InboundAppResponse",
   368  		func(t *testing.T) {
   369  			require := require.New(t)
   370  
   371  			msg := InboundAppResponse(
   372  				chainID,
   373  				requestID,
   374  				appBytes,
   375  				nodeID,
   376  			)
   377  
   378  			require.Equal(AppResponseOp, msg.Op())
   379  			require.Equal(nodeID, msg.NodeID())
   380  			require.Equal(mockable.MaxTime, msg.Expiration())
   381  			require.IsType(&p2p.AppResponse{}, msg.Message())
   382  			innerMsg := msg.Message().(*p2p.AppResponse)
   383  			require.Equal(chainID[:], innerMsg.ChainId)
   384  			require.Equal(requestID, innerMsg.RequestId)
   385  			require.Equal(appBytes, innerMsg.AppBytes)
   386  		},
   387  	)
   388  }
   389  
   390  func TestAppError(t *testing.T) {
   391  	require := require.New(t)
   392  
   393  	mb, err := newMsgBuilder(
   394  		logging.NoLog{},
   395  		prometheus.NewRegistry(),
   396  		time.Second,
   397  	)
   398  	require.NoError(err)
   399  
   400  	nodeID := ids.GenerateTestNodeID()
   401  	chainID := ids.GenerateTestID()
   402  	requestID := uint32(1)
   403  	errorCode := int32(2)
   404  	errorMessage := "hello world"
   405  
   406  	want := &p2p.Message{
   407  		Message: &p2p.Message_AppError{
   408  			AppError: &p2p.AppError{
   409  				ChainId:      chainID[:],
   410  				RequestId:    requestID,
   411  				ErrorCode:    errorCode,
   412  				ErrorMessage: errorMessage,
   413  			},
   414  		},
   415  	}
   416  
   417  	outMsg, err := mb.createOutbound(want, compression.TypeNone, false)
   418  	require.NoError(err)
   419  
   420  	got, err := mb.parseInbound(outMsg.Bytes(), nodeID, func() {})
   421  	require.NoError(err)
   422  
   423  	require.Equal(nodeID, got.NodeID())
   424  	require.Equal(AppErrorOp, got.Op())
   425  
   426  	msg, ok := got.Message().(*p2p.AppError)
   427  	require.True(ok)
   428  	require.Equal(errorCode, msg.ErrorCode)
   429  	require.Equal(errorMessage, msg.ErrorMessage)
   430  }