github.com/klaytn/klaytn@v1.12.1/node/cn/snap/handler_test.go (about)

     1  // Copyright 2022 The klaytn Authors
     2  // This file is part of the klaytn library.
     3  //
     4  // The klaytn library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The klaytn library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the klaytn library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package snap
    18  
    19  import (
    20  	"errors"
    21  	"math/big"
    22  	"strings"
    23  	"testing"
    24  
    25  	"github.com/klaytn/klaytn/blockchain/state"
    26  	"github.com/klaytn/klaytn/blockchain/types/account"
    27  	"github.com/klaytn/klaytn/common"
    28  	"github.com/klaytn/klaytn/networks/p2p"
    29  	"github.com/klaytn/klaytn/rlp"
    30  	"github.com/klaytn/klaytn/snapshot"
    31  	"github.com/klaytn/klaytn/storage/database"
    32  	"github.com/klaytn/klaytn/storage/statedb"
    33  	"github.com/stretchr/testify/assert"
    34  )
    35  
    36  type testMsgRW struct {
    37  	reader func() (p2p.Msg, error)
    38  	writer func(msg p2p.Msg) error
    39  }
    40  
    41  func (rw *testMsgRW) ReadMsg() (p2p.Msg, error)  { return rw.reader() }
    42  func (rw *testMsgRW) WriteMsg(msg p2p.Msg) error { return nil }
    43  
    44  type testDownloader struct{}
    45  
    46  func (d *testDownloader) DeliverSnapPacket(peer *Peer, packet Packet) error { return nil }
    47  
    48  type testSnapshotReader struct {
    49  	db   state.Database
    50  	snap *snapshot.Tree
    51  }
    52  
    53  type testKV struct {
    54  	k []byte
    55  	v []byte
    56  }
    57  
    58  func NewTestSnapshotReader(items []*testKV) (*testSnapshotReader, common.Hash) {
    59  	memdb := database.NewMemoryDBManager()
    60  	db := state.NewDatabase(memdb)
    61  	trie, _ := statedb.NewTrie(common.Hash{}, db.TrieDB(), nil)
    62  	for _, kv := range items {
    63  		trie.Update(kv.k, kv.v)
    64  	}
    65  	root, _ := trie.Commit(nil)
    66  	db.TrieDB().Commit(root, false, 0)
    67  
    68  	snap, _ := snapshot.New(memdb, db.TrieDB(), 256, root, false, true, false)
    69  
    70  	return &testSnapshotReader{
    71  		db,
    72  		snap,
    73  	}, root
    74  }
    75  
    76  func (r *testSnapshotReader) StateCache() state.Database {
    77  	return r.db
    78  }
    79  
    80  func (r *testSnapshotReader) Snapshots() *snapshot.Tree {
    81  	return r.snap
    82  }
    83  
    84  func (r *testSnapshotReader) ContractCode(hash common.Hash) ([]byte, error) {
    85  	return r.db.ContractCode(hash)
    86  }
    87  
    88  func (r *testSnapshotReader) ContractCodeWithPrefix(hash common.Hash) ([]byte, error) {
    89  	return nil, nil
    90  }
    91  
    92  func createMsg(msgcode uint64, data interface{}) (p2p.Msg, error) {
    93  	size, r, err := rlp.EncodeToReader(data)
    94  	if err != nil {
    95  		return p2p.Msg{}, err
    96  	}
    97  	return p2p.Msg{Code: msgcode, Size: uint32(size), Payload: r}, nil
    98  }
    99  
   100  func createAccountRangeReqMsg(root common.Hash) p2p.Msg {
   101  	msg, _ := createMsg(GetAccountRangeMsg, &GetAccountRangePacket{
   102  		ID:     1,
   103  		Root:   root,
   104  		Origin: common.Hash{},
   105  		Limit:  common.Hash{},
   106  		Bytes:  softResponseLimit,
   107  	})
   108  	return msg
   109  }
   110  
   111  func mockPeer(msg p2p.Msg) *Peer {
   112  	mockPeer := NewFakePeer(1, common.BytesToHash([]byte{0x1}).String(), &testMsgRW{reader: func() (p2p.Msg, error) {
   113  		return msg, nil
   114  	}})
   115  	return mockPeer
   116  }
   117  
   118  func TestMessageDecoding(t *testing.T) {
   119  	var (
   120  		msg p2p.Msg
   121  		err error
   122  	)
   123  	msg, err = createMsg(GetAccountRangeMsg, &GetAccountRangePacket{
   124  		ID:     0,
   125  		Root:   common.Hash{},
   126  		Origin: common.Hash{},
   127  		Limit:  common.Hash{},
   128  		Bytes:  0,
   129  	})
   130  	assert.NoError(t, err)
   131  	var req1 GetAccountRangePacket
   132  	assert.NoError(t, msg.Decode(&req1))
   133  
   134  	msg, err = createMsg(AccountRangeMsg, &AccountRangePacket{
   135  		ID:       0,
   136  		Accounts: nil,
   137  		Proof:    nil,
   138  	})
   139  	assert.NoError(t, err)
   140  	var req2 AccountRangePacket
   141  	assert.NoError(t, msg.Decode(&req2))
   142  
   143  	msg, err = createMsg(GetStorageRangesMsg, &GetStorageRangesPacket{
   144  		ID:       0,
   145  		Root:     common.Hash{},
   146  		Accounts: nil,
   147  		Origin:   nil,
   148  		Limit:    nil,
   149  		Bytes:    0,
   150  	})
   151  	assert.NoError(t, err)
   152  	var req3 GetStorageRangesPacket
   153  	assert.NoError(t, msg.Decode(&req3))
   154  
   155  	msg, err = createMsg(StorageRangesMsg, &StorageRangesPacket{
   156  		ID:    0,
   157  		Slots: nil,
   158  		Proof: nil,
   159  	})
   160  	assert.NoError(t, err)
   161  	var req4 StorageRangesPacket
   162  	assert.NoError(t, msg.Decode(&req4))
   163  
   164  	msg, err = createMsg(GetByteCodesMsg, &GetByteCodesPacket{
   165  		ID:     0,
   166  		Hashes: nil,
   167  		Bytes:  0,
   168  	})
   169  	assert.NoError(t, err)
   170  	var req5 GetByteCodesPacket
   171  	assert.NoError(t, msg.Decode(&req5))
   172  
   173  	msg, err = createMsg(ByteCodesMsg, &ByteCodesPacket{
   174  		ID:    0,
   175  		Codes: nil,
   176  	})
   177  	assert.NoError(t, err)
   178  	var req6 ByteCodesPacket
   179  	assert.NoError(t, msg.Decode(&req6))
   180  
   181  	msg, err = createMsg(GetTrieNodesMsg, &GetTrieNodesPacket{
   182  		ID:    0,
   183  		Root:  common.Hash{},
   184  		Paths: nil,
   185  		Bytes: 0,
   186  	})
   187  	assert.NoError(t, err)
   188  	var req7 GetTrieNodesPacket
   189  	assert.NoError(t, msg.Decode(&req7))
   190  
   191  	msg, err = createMsg(TrieNodesMsg, &TrieNodesPacket{
   192  		ID:    0,
   193  		Nodes: nil,
   194  	})
   195  	assert.NoError(t, err)
   196  	var req8 TrieNodesPacket
   197  	assert.NoError(t, msg.Decode(&req8))
   198  }
   199  
   200  func TestHandleMessage_ReadMsgErr(t *testing.T) {
   201  	reader := &testSnapshotReader{}
   202  
   203  	// create test message
   204  	msg, _ := createMsg(GetAccountRangeMsg, []byte{0x1})
   205  	msg.Size = maxMessageSize + 1
   206  	peer := mockPeer(msg)
   207  	testErr := errors.New("test error")
   208  	peer.rw = &testMsgRW{reader: func() (p2p.Msg, error) { return p2p.Msg{}, testErr }}
   209  
   210  	// failed to handle message due to read msg error
   211  	err := HandleMessage(reader, &testDownloader{}, peer)
   212  	assert.Equal(t, err, testErr)
   213  }
   214  
   215  func TestHandleMessage_LargeMessageErr(t *testing.T) {
   216  	reader := &testSnapshotReader{}
   217  
   218  	// create test message
   219  	msg, _ := createMsg(GetAccountRangeMsg, []byte{0x1})
   220  	msg.Size = maxMessageSize + 1
   221  	peer := mockPeer(msg)
   222  
   223  	// failed to handle message due to too large message size
   224  	err := HandleMessage(reader, &testDownloader{}, peer)
   225  	assert.True(t, strings.Contains(err.Error(), errMsgTooLarge.Error()))
   226  }
   227  
   228  func TestHandleMessage_LargeMessageInvalidMsgErr(t *testing.T) {
   229  	reader := &testSnapshotReader{}
   230  
   231  	// create test message
   232  	msg, _ := createMsg(0x08, []byte{0x1})
   233  	peer := mockPeer(msg)
   234  
   235  	// failed to handle message due to too large message size
   236  	err := HandleMessage(reader, &testDownloader{}, peer)
   237  	assert.True(t, strings.Contains(err.Error(), errInvalidMsgCode.Error()))
   238  }
   239  
   240  func TestHandleMessage_GetAccountRange_EmptyItem(t *testing.T) {
   241  	items := []*testKV{}
   242  	reader, root := NewTestSnapshotReader(items)
   243  
   244  	err := HandleMessage(reader, &testDownloader{}, mockPeer(createAccountRangeReqMsg(root)))
   245  	assert.NoError(t, err)
   246  }
   247  
   248  func TestHandleMessage_Success(t *testing.T) {
   249  	items := []*testKV{}
   250  	for i := uint64(1); i <= 100; i++ {
   251  		acc, _ := genExternallyOwnedAccount(1, big.NewInt(1))
   252  		serializer := account.NewAccountSerializerWithAccount(acc)
   253  		bytes, _ := rlp.EncodeToBytes(serializer)
   254  		items = append(items, &testKV{key32(i), bytes})
   255  	}
   256  
   257  	reader, root := NewTestSnapshotReader(items)
   258  	var (
   259  		msgs []p2p.Msg
   260  		msg  p2p.Msg
   261  		err  error
   262  	)
   263  
   264  	msg, err = createMsg(GetAccountRangeMsg, &GetAccountRangePacket{ID: 1, Root: root})
   265  	assert.NoError(t, err)
   266  	msgs = append(msgs, msg)
   267  
   268  	msg, err = createMsg(AccountRangeMsg, &AccountRangePacket{})
   269  	assert.NoError(t, err)
   270  	msgs = append(msgs, msg)
   271  
   272  	msg, err = createMsg(GetStorageRangesMsg, &GetStorageRangesPacket{Root: root})
   273  	assert.NoError(t, err)
   274  	msgs = append(msgs, msg)
   275  
   276  	msg, err = createMsg(StorageRangesMsg, &StorageRangesPacket{})
   277  	assert.NoError(t, err)
   278  	msgs = append(msgs, msg)
   279  
   280  	msg, err = createMsg(GetByteCodesMsg, &GetByteCodesPacket{})
   281  	assert.NoError(t, err)
   282  	msgs = append(msgs, msg)
   283  
   284  	msg, err = createMsg(ByteCodesMsg, &ByteCodesPacket{})
   285  	assert.NoError(t, err)
   286  	msgs = append(msgs, msg)
   287  
   288  	msg, err = createMsg(GetTrieNodesMsg, &GetTrieNodesPacket{Root: root})
   289  	assert.NoError(t, err)
   290  	msgs = append(msgs, msg)
   291  
   292  	msg, err = createMsg(TrieNodesMsg, &TrieNodesPacket{})
   293  	assert.NoError(t, err)
   294  	msgs = append(msgs, msg)
   295  
   296  	for _, msg := range msgs {
   297  		err := HandleMessage(reader, &testDownloader{}, mockPeer(msg))
   298  		assert.NoError(t, err)
   299  	}
   300  }