github.com/nspcc-dev/neo-go@v0.105.2-0.20240517133400-6be757af3eba/pkg/network/message_test.go (about)

     1  package network
     2  
     3  import (
     4  	"errors"
     5  	"math/rand"
     6  	"testing"
     7  	"time"
     8  
     9  	"github.com/nspcc-dev/neo-go/internal/random"
    10  	"github.com/nspcc-dev/neo-go/internal/testserdes"
    11  	"github.com/nspcc-dev/neo-go/pkg/config/netmode"
    12  	"github.com/nspcc-dev/neo-go/pkg/core/block"
    13  	"github.com/nspcc-dev/neo-go/pkg/core/transaction"
    14  	"github.com/nspcc-dev/neo-go/pkg/io"
    15  	"github.com/nspcc-dev/neo-go/pkg/network/capability"
    16  	"github.com/nspcc-dev/neo-go/pkg/network/payload"
    17  	"github.com/nspcc-dev/neo-go/pkg/util"
    18  	"github.com/stretchr/testify/require"
    19  )
    20  
    21  func TestMessageDecodeFuzzCases(t *testing.T) {
    22  	raw := []byte("10\x0200")
    23  	m := new(Message)
    24  	r := io.NewBinReaderFromBuf(raw)
    25  	require.NotPanics(t, func() { _ = m.Decode(r) })
    26  }
    27  
    28  func TestEncodeDecodeVersion(t *testing.T) {
    29  	// message with tiny payload, shouldn't be compressed
    30  	expected := NewMessage(CMDVersion, &payload.Version{
    31  		Magic:     1,
    32  		Version:   2,
    33  		Timestamp: uint32(time.Now().UnixNano()),
    34  		Nonce:     987,
    35  		UserAgent: []byte{1, 2, 3},
    36  		Capabilities: capability.Capabilities{
    37  			{
    38  				Type: capability.FullNode,
    39  				Data: &capability.Node{
    40  					StartHeight: 123,
    41  				},
    42  			},
    43  		},
    44  	})
    45  	testserdes.EncodeDecode(t, expected, &Message{})
    46  	uncompressed, err := testserdes.EncodeBinary(expected.Payload)
    47  	require.NoError(t, err)
    48  	require.Equal(t, len(expected.compressedPayload), len(uncompressed))
    49  
    50  	// large payload should be compressed
    51  	largeArray := make([]byte, CompressionMinSize)
    52  	for i := range largeArray {
    53  		largeArray[i] = byte(i)
    54  	}
    55  	expected.Payload.(*payload.Version).UserAgent = largeArray
    56  	testserdes.EncodeDecode(t, expected, &Message{})
    57  	uncompressed, err = testserdes.EncodeBinary(expected.Payload)
    58  	require.NoError(t, err)
    59  	require.NotEqual(t, len(expected.compressedPayload), len(uncompressed))
    60  }
    61  
    62  func BenchmarkMessageBytes(b *testing.B) {
    63  	// shouldn't try to compress headers payload
    64  	ep := &payload.Extensible{
    65  		Category:        "consensus",
    66  		ValidBlockStart: rand.Uint32(),
    67  		ValidBlockEnd:   rand.Uint32(),
    68  		Sender:          util.Uint160{},
    69  		Data:            make([]byte, 300),
    70  		Witness: transaction.Witness{
    71  			InvocationScript:   make([]byte, 33),
    72  			VerificationScript: make([]byte, 40),
    73  		},
    74  	}
    75  	random.Fill(ep.Data)
    76  	random.Fill(ep.Witness.InvocationScript)
    77  	random.Fill(ep.Witness.VerificationScript)
    78  	msg := NewMessage(CMDExtensible, ep)
    79  
    80  	b.ReportAllocs()
    81  	b.ResetTimer()
    82  	for i := 0; i < b.N; i++ {
    83  		_, err := msg.Bytes()
    84  		if err != nil {
    85  			b.FailNow()
    86  		}
    87  	}
    88  }
    89  
    90  func TestEncodeDecodeHeaders(t *testing.T) {
    91  	// shouldn't try to compress headers payload
    92  	headers := &payload.Headers{Hdrs: make([]*block.Header, CompressionMinSize)}
    93  	for i := range headers.Hdrs {
    94  		h := &block.Header{
    95  			Index: uint32(i + 1),
    96  			Script: transaction.Witness{
    97  				InvocationScript:   []byte{0x0},
    98  				VerificationScript: []byte{0x1},
    99  			},
   100  		}
   101  		h.Hash()
   102  		headers.Hdrs[i] = h
   103  	}
   104  	expected := NewMessage(CMDHeaders, headers)
   105  	testserdes.EncodeDecode(t, expected, &Message{})
   106  	uncompressed, err := testserdes.EncodeBinary(expected.Payload)
   107  	require.NoError(t, err)
   108  	require.Equal(t, len(expected.compressedPayload), len(uncompressed))
   109  }
   110  
   111  func TestEncodeDecodeGetAddr(t *testing.T) {
   112  	// NullPayload should be handled properly
   113  	testEncodeDecode(t, CMDGetAddr, payload.NewNullPayload())
   114  }
   115  
   116  func TestEncodeDecodeNil(t *testing.T) {
   117  	// nil payload should be decoded into NullPayload
   118  	expected := NewMessage(CMDGetAddr, nil)
   119  	encoded, err := testserdes.Encode(expected)
   120  	require.NoError(t, err)
   121  	decoded := &Message{}
   122  	err = testserdes.Decode(encoded, decoded)
   123  	require.NoError(t, err)
   124  	require.Equal(t, NewMessage(CMDGetAddr, payload.NewNullPayload()), decoded)
   125  }
   126  
   127  func TestEncodeDecodePing(t *testing.T) {
   128  	testEncodeDecode(t, CMDPing, payload.NewPing(123, 456))
   129  }
   130  
   131  func TestEncodeDecodeInventory(t *testing.T) {
   132  	testEncodeDecode(t, CMDInv, payload.NewInventory(payload.ExtensibleType, []util.Uint256{{1, 2, 3}}))
   133  }
   134  
   135  func TestEncodeDecodeAddr(t *testing.T) {
   136  	const count = 3
   137  	p := payload.NewAddressList(count)
   138  	p.Addrs[0] = &payload.AddressAndTime{
   139  		Timestamp: rand.Uint32(),
   140  		Capabilities: capability.Capabilities{{
   141  			Type: capability.FullNode,
   142  			Data: &capability.Node{StartHeight: rand.Uint32()},
   143  		}},
   144  	}
   145  	p.Addrs[1] = &payload.AddressAndTime{
   146  		Timestamp: rand.Uint32(),
   147  		Capabilities: capability.Capabilities{{
   148  			Type: capability.TCPServer,
   149  			Data: &capability.Server{Port: uint16(rand.Uint32())},
   150  		}},
   151  	}
   152  	p.Addrs[2] = &payload.AddressAndTime{
   153  		Timestamp: rand.Uint32(),
   154  		Capabilities: capability.Capabilities{{
   155  			Type: capability.WSServer,
   156  			Data: &capability.Server{Port: uint16(rand.Uint32())},
   157  		}},
   158  	}
   159  	testEncodeDecode(t, CMDAddr, p)
   160  }
   161  
   162  func TestEncodeDecodeBlock(t *testing.T) {
   163  	t.Run("good", func(t *testing.T) {
   164  		testEncodeDecode(t, CMDBlock, newDummyBlock(12, 1))
   165  	})
   166  	t.Run("invalid state root enabled setting", func(t *testing.T) {
   167  		expected := NewMessage(CMDBlock, newDummyBlock(31, 1))
   168  		data, err := testserdes.Encode(expected)
   169  		require.NoError(t, err)
   170  		require.Error(t, testserdes.Decode(data, &Message{StateRootInHeader: true}))
   171  	})
   172  }
   173  
   174  func TestEncodeDecodeGetBlock(t *testing.T) {
   175  	t.Run("good, Count>0", func(t *testing.T) {
   176  		testEncodeDecode(t, CMDGetBlocks, &payload.GetBlocks{
   177  			HashStart: random.Uint256(),
   178  			Count:     int16(rand.Uint32() >> 17),
   179  		})
   180  	})
   181  	t.Run("good, Count=-1", func(t *testing.T) {
   182  		testEncodeDecode(t, CMDGetBlocks, &payload.GetBlocks{
   183  			HashStart: random.Uint256(),
   184  			Count:     -1,
   185  		})
   186  	})
   187  	t.Run("bad, Count=-2", func(t *testing.T) {
   188  		testEncodeDecodeFail(t, CMDGetBlocks, &payload.GetBlocks{
   189  			HashStart: random.Uint256(),
   190  			Count:     -2,
   191  		})
   192  	})
   193  }
   194  
   195  func TestEnodeDecodeGetHeaders(t *testing.T) {
   196  	testEncodeDecode(t, CMDGetHeaders, &payload.GetBlockByIndex{
   197  		IndexStart: rand.Uint32(),
   198  		Count:      payload.MaxHeadersAllowed,
   199  	})
   200  }
   201  
   202  func TestEncodeDecodeGetBlockByIndex(t *testing.T) {
   203  	t.Run("good, Count>0", func(t *testing.T) {
   204  		testEncodeDecode(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{
   205  			IndexStart: rand.Uint32(),
   206  			Count:      payload.MaxHeadersAllowed,
   207  		})
   208  	})
   209  	t.Run("bad, Count too big", func(t *testing.T) {
   210  		testEncodeDecodeFail(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{
   211  			IndexStart: rand.Uint32(),
   212  			Count:      payload.MaxHeadersAllowed + 1,
   213  		})
   214  	})
   215  	t.Run("good, Count=-1", func(t *testing.T) {
   216  		testEncodeDecode(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{
   217  			IndexStart: rand.Uint32(),
   218  			Count:      -1,
   219  		})
   220  	})
   221  	t.Run("bad, Count=-2", func(t *testing.T) {
   222  		testEncodeDecodeFail(t, CMDGetBlockByIndex, &payload.GetBlockByIndex{
   223  			IndexStart: rand.Uint32(),
   224  			Count:      -2,
   225  		})
   226  	})
   227  }
   228  
   229  func TestEncodeDecodeTransaction(t *testing.T) {
   230  	testEncodeDecode(t, CMDTX, newDummyTx())
   231  }
   232  
   233  func TestEncodeDecodeMerkleBlock(t *testing.T) {
   234  	base := &block.Header{
   235  		PrevHash:  random.Uint256(),
   236  		Timestamp: rand.Uint64(),
   237  		Script: transaction.Witness{
   238  			InvocationScript:   random.Bytes(10),
   239  			VerificationScript: random.Bytes(11),
   240  		},
   241  	}
   242  	base.Hash()
   243  	t.Run("good", func(t *testing.T) {
   244  		testEncodeDecode(t, CMDMerkleBlock, &payload.MerkleBlock{
   245  			Header:  base,
   246  			TxCount: 1,
   247  			Hashes:  []util.Uint256{random.Uint256()},
   248  			Flags:   []byte{0},
   249  		})
   250  	})
   251  	t.Run("bad, invalid TxCount", func(t *testing.T) {
   252  		testEncodeDecodeFail(t, CMDMerkleBlock, &payload.MerkleBlock{
   253  			Header:  base,
   254  			TxCount: 2,
   255  			Hashes:  []util.Uint256{random.Uint256()},
   256  			Flags:   []byte{0},
   257  		})
   258  	})
   259  }
   260  
   261  func TestEncodeDecodeNotFound(t *testing.T) {
   262  	testEncodeDecode(t, CMDNotFound, &payload.Inventory{
   263  		Type:   payload.TXType,
   264  		Hashes: []util.Uint256{random.Uint256()},
   265  	})
   266  }
   267  
   268  func TestEncodeDecodeGetMPTData(t *testing.T) {
   269  	testEncodeDecode(t, CMDGetMPTData, &payload.MPTInventory{
   270  		Hashes: []util.Uint256{
   271  			{1, 2, 3},
   272  			{4, 5, 6},
   273  		},
   274  	})
   275  }
   276  
   277  func TestEncodeDecodeMPTData(t *testing.T) {
   278  	testEncodeDecode(t, CMDMPTData, &payload.MPTData{
   279  		Nodes: [][]byte{{1, 2, 3}, {4, 5, 6}},
   280  	})
   281  }
   282  
   283  func TestInvalidMessages(t *testing.T) {
   284  	t.Run("CMDBlock, empty payload", func(t *testing.T) {
   285  		testEncodeDecodeFail(t, CMDBlock, payload.NullPayload{})
   286  	})
   287  	t.Run("send decompressed with flag", func(t *testing.T) {
   288  		m := NewMessage(CMDTX, newDummyTx())
   289  		data, err := testserdes.Encode(m)
   290  		require.NoError(t, err)
   291  		require.True(t, m.Flags&Compressed == 0)
   292  		data[0] |= byte(Compressed)
   293  		require.Error(t, testserdes.Decode(data, &Message{}))
   294  	})
   295  	t.Run("invalid command", func(t *testing.T) {
   296  		testEncodeDecodeFail(t, CommandType(0xFF), &payload.Version{Magic: netmode.UnitTestNet})
   297  	})
   298  	t.Run("very big payload size", func(t *testing.T) {
   299  		m := NewMessage(CMDBlock, nil)
   300  		w := io.NewBufBinWriter()
   301  		w.WriteB(byte(m.Flags))
   302  		w.WriteB(byte(m.Command))
   303  		w.WriteVarBytes(make([]byte, payload.MaxSize+1))
   304  		require.NoError(t, w.Err)
   305  		require.Error(t, testserdes.Decode(w.Bytes(), &Message{}))
   306  	})
   307  	t.Run("fail to encode message if payload can't be serialized", func(t *testing.T) {
   308  		m := NewMessage(CMDBlock, failSer(true))
   309  		_, err := m.Bytes()
   310  		require.Error(t, err)
   311  
   312  		// good otherwise
   313  		m = NewMessage(CMDBlock, failSer(false))
   314  		_, err = m.Bytes()
   315  		require.NoError(t, err)
   316  	})
   317  	t.Run("trimmed payload", func(t *testing.T) {
   318  		m := NewMessage(CMDBlock, newDummyBlock(1, 0))
   319  		data, err := testserdes.Encode(m)
   320  		require.NoError(t, err)
   321  		data = data[:len(data)-1]
   322  		require.Error(t, testserdes.Decode(data, &Message{}))
   323  	})
   324  }
   325  
   326  type failSer bool
   327  
   328  func (f failSer) EncodeBinary(r *io.BinWriter) {
   329  	if f {
   330  		r.Err = errors.New("unserializable payload")
   331  	}
   332  }
   333  
   334  func (failSer) DecodeBinary(w *io.BinReader) {}
   335  
   336  func newDummyBlock(height uint32, txCount int) *block.Block {
   337  	b := block.New(false)
   338  	b.Index = height
   339  	b.PrevHash = random.Uint256()
   340  	b.Timestamp = rand.Uint64()
   341  	b.Script.InvocationScript = random.Bytes(2)
   342  	b.Script.VerificationScript = random.Bytes(3)
   343  	b.Transactions = make([]*transaction.Transaction, txCount)
   344  	for i := range b.Transactions {
   345  		b.Transactions[i] = newDummyTx()
   346  	}
   347  	b.Hash()
   348  	return b
   349  }
   350  
   351  func newDummyTx() *transaction.Transaction {
   352  	tx := transaction.New(random.Bytes(100), 123)
   353  	tx.Signers = []transaction.Signer{{Account: random.Uint160()}}
   354  	tx.Scripts = []transaction.Witness{{InvocationScript: []byte{}, VerificationScript: []byte{}}}
   355  	tx.Size()
   356  	tx.Hash()
   357  	return tx
   358  }
   359  
   360  func testEncodeDecode(t *testing.T, cmd CommandType, p payload.Payload) *Message {
   361  	expected := NewMessage(cmd, p)
   362  	actual := &Message{}
   363  	testserdes.EncodeDecode(t, expected, actual)
   364  	return actual
   365  }
   366  
   367  func testEncodeDecodeFail(t *testing.T, cmd CommandType, p payload.Payload) *Message {
   368  	expected := NewMessage(cmd, p)
   369  	data, err := testserdes.Encode(expected)
   370  	require.NoError(t, err)
   371  
   372  	actual := &Message{}
   373  	require.Error(t, testserdes.Decode(data, actual))
   374  	return actual
   375  }