github.com/palisadeinc/bor@v0.0.0-20230615125219-ab7196213d15/eth/downloader/whitelist/service_test.go (about)

     1  package whitelist
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"math/big"
     7  	"reflect"
     8  	"sort"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/stretchr/testify/require"
    13  
    14  	"github.com/ethereum/go-ethereum/common"
    15  	"github.com/ethereum/go-ethereum/core/types"
    16  )
    17  
    18  // NewMockService creates a new mock whitelist service
    19  func NewMockService(maxCapacity uint, checkpointInterval uint64) *Service {
    20  	return &Service{
    21  		checkpointWhitelist: make(map[uint64]common.Hash),
    22  		checkpointOrder:     []uint64{},
    23  		maxCapacity:         maxCapacity,
    24  		checkpointInterval:  checkpointInterval,
    25  	}
    26  }
    27  
    28  // TestWhitelistCheckpoint checks the checkpoint whitelist map queue mechanism
    29  func TestWhitelistCheckpoint(t *testing.T) {
    30  	t.Parallel()
    31  
    32  	s := NewMockService(10, 10)
    33  	for i := 0; i < 10; i++ {
    34  		s.enqueueCheckpointWhitelist(uint64(i), common.Hash{})
    35  	}
    36  	require.Equal(t, s.length(), 10, "expected 10 items in whitelist")
    37  
    38  	s.enqueueCheckpointWhitelist(11, common.Hash{})
    39  	s.dequeueCheckpointWhitelist()
    40  	require.Equal(t, s.length(), 10, "expected 10 items in whitelist")
    41  }
    42  
    43  // TestIsValidPeer checks the IsValidPeer function in isolation
    44  // for different cases by providing a mock fetchHeadersByNumber function
    45  func TestIsValidPeer(t *testing.T) {
    46  	t.Parallel()
    47  
    48  	s := NewMockService(10, 10)
    49  
    50  	// case1: no checkpoint whitelist, should consider the chain as valid
    51  	res, err := s.IsValidPeer(nil, nil)
    52  	require.NoError(t, err, "expected no error")
    53  	require.Equal(t, res, true, "expected chain to be valid")
    54  
    55  	// add checkpoint entries and mock fetchHeadersByNumber function
    56  	s.ProcessCheckpoint(uint64(0), common.Hash{})
    57  	s.ProcessCheckpoint(uint64(1), common.Hash{})
    58  
    59  	require.Equal(t, s.length(), 2, "expected 2 items in whitelist")
    60  
    61  	// create a false function, returning absolutely nothing
    62  	falseFetchHeadersByNumber := func(number uint64, amount int, skip int, reverse bool) ([]*types.Header, []common.Hash, error) {
    63  		return nil, nil, nil
    64  	}
    65  
    66  	// case2: false fetchHeadersByNumber function provided, should consider the chain as invalid
    67  	// and throw `ErrNoRemoteCheckpoint` error
    68  	res, err = s.IsValidPeer(nil, falseFetchHeadersByNumber)
    69  	if err == nil {
    70  		t.Fatal("expected error, got nil")
    71  	}
    72  
    73  	if !errors.Is(err, ErrNoRemoteCheckpoint) {
    74  		t.Fatalf("expected error ErrNoRemoteCheckpoint, got %v", err)
    75  	}
    76  
    77  	require.Equal(t, res, false, "expected chain to be invalid")
    78  
    79  	// case3: correct fetchHeadersByNumber function provided, should consider the chain as valid
    80  	// create a mock function, returning a the required header
    81  	fetchHeadersByNumber := func(number uint64, _ int, _ int, _ bool) ([]*types.Header, []common.Hash, error) {
    82  		hash := common.Hash{}
    83  		header := types.Header{Number: big.NewInt(0)}
    84  
    85  		switch number {
    86  		case 0:
    87  			return []*types.Header{&header}, []common.Hash{hash}, nil
    88  		case 1:
    89  			header.Number = big.NewInt(1)
    90  			return []*types.Header{&header}, []common.Hash{hash}, nil
    91  		case 2:
    92  			header.Number = big.NewInt(1) // sending wrong header for misamatch
    93  			return []*types.Header{&header}, []common.Hash{hash}, nil
    94  		default:
    95  			return nil, nil, errors.New("invalid number")
    96  		}
    97  	}
    98  
    99  	res, err = s.IsValidPeer(nil, fetchHeadersByNumber)
   100  	require.NoError(t, err, "expected no error")
   101  	require.Equal(t, res, true, "expected chain to be valid")
   102  
   103  	// add one more checkpoint whitelist entry
   104  	s.ProcessCheckpoint(uint64(2), common.Hash{})
   105  	require.Equal(t, s.length(), 3, "expected 3 items in whitelist")
   106  
   107  	// case4: correct fetchHeadersByNumber function provided with wrong header
   108  	// for block number 2. Should consider the chain as invalid and throw an error
   109  	res, err = s.IsValidPeer(nil, fetchHeadersByNumber)
   110  	require.Equal(t, err, ErrCheckpointMismatch, "expected checkpoint mismatch error")
   111  	require.Equal(t, res, false, "expected chain to be invalid")
   112  }
   113  
   114  // TestIsValidChain checks the IsValidChain function in isolation
   115  // for different cases by providing a mock current header and chain
   116  func TestIsValidChain(t *testing.T) {
   117  	t.Parallel()
   118  
   119  	s := NewMockService(10, 10)
   120  	chainA := createMockChain(1, 20) // A1->A2...A19->A20
   121  	// case1: no checkpoint whitelist, should consider the chain as valid
   122  	res, err := s.IsValidChain(nil, chainA)
   123  	require.Equal(t, res, true, "expected chain to be valid")
   124  	require.Equal(t, err, nil, "expected error to be nil")
   125  
   126  	tempChain := createMockChain(21, 22) // A21->A22
   127  
   128  	// add mock checkpoint entries
   129  	s.ProcessCheckpoint(tempChain[0].Number.Uint64(), tempChain[0].Hash())
   130  	s.ProcessCheckpoint(tempChain[1].Number.Uint64(), tempChain[1].Hash())
   131  
   132  	require.Equal(t, s.length(), 2, "expected 2 items in whitelist")
   133  
   134  	// case2: We're behind the oldest whitelisted block entry, should consider
   135  	// the chain as valid as we're still far behind the latest blocks
   136  	res, err = s.IsValidChain(chainA[len(chainA)-1], chainA)
   137  	require.Equal(t, res, true, "expected chain to be valid")
   138  	require.Equal(t, err, nil, "expected error to be nil")
   139  
   140  	// Clear checkpoint whitelist and add blocks A5 and A15 in whitelist
   141  	s.PurgeCheckpointWhitelist()
   142  	s.ProcessCheckpoint(chainA[5].Number.Uint64(), chainA[5].Hash())
   143  	s.ProcessCheckpoint(chainA[15].Number.Uint64(), chainA[15].Hash())
   144  
   145  	require.Equal(t, s.length(), 2, "expected 2 items in whitelist")
   146  
   147  	// case3: Try importing a past chain having valid checkpoint, should
   148  	// consider the chain as valid
   149  	res, err = s.IsValidChain(chainA[len(chainA)-1], chainA)
   150  	require.Equal(t, res, true, "expected chain to be valid")
   151  	require.Equal(t, err, nil, "expected error to be nil")
   152  
   153  	// Clear checkpoint whitelist and mock blocks in whitelist
   154  	tempChain = createMockChain(20, 20) // A20
   155  
   156  	s.PurgeCheckpointWhitelist()
   157  	s.ProcessCheckpoint(tempChain[0].Number.Uint64(), tempChain[0].Hash())
   158  
   159  	require.Equal(t, s.length(), 1, "expected 1 items in whitelist")
   160  
   161  	// case4: Try importing a past chain having invalid checkpoint
   162  	res, _ = s.IsValidChain(chainA[len(chainA)-1], chainA)
   163  	require.Equal(t, res, false, "expected chain to be invalid")
   164  	// Not checking error here because we return nil in case of checkpoint mismatch
   165  
   166  	// create a future chain to be imported of length <= `checkpointInterval`
   167  	chainB := createMockChain(21, 30) // B21->B22...B29->B30
   168  
   169  	// case5: Try importing a future chain (1)
   170  	res, err = s.IsValidChain(chainA[len(chainA)-1], chainB)
   171  	require.Equal(t, res, true, "expected chain to be valid")
   172  	require.Equal(t, err, nil, "expected error to be nil")
   173  
   174  	// create a future chain to be imported of length > `checkpointInterval`
   175  	chainB = createMockChain(21, 40) // C21->C22...C39->C40
   176  
   177  	// Note: Earlier, it used to reject future chains longer than some threshold.
   178  	// That check is removed for now.
   179  
   180  	// case6: Try importing a future chain (2)
   181  	res, err = s.IsValidChain(chainA[len(chainA)-1], chainB)
   182  	require.Equal(t, res, true, "expected chain to be valid")
   183  	require.Equal(t, err, nil, "expected error to be nil")
   184  }
   185  
   186  func TestSplitChain(t *testing.T) {
   187  	t.Parallel()
   188  
   189  	type Result struct {
   190  		pastStart    uint64
   191  		pastEnd      uint64
   192  		futureStart  uint64
   193  		futureEnd    uint64
   194  		pastLength   int
   195  		futureLength int
   196  	}
   197  
   198  	// Current chain is at block: X
   199  	// Incoming chain is represented as [N, M]
   200  	testCases := []struct {
   201  		name    string
   202  		current uint64
   203  		chain   []*types.Header
   204  		result  Result
   205  	}{
   206  		{name: "X = 10, N = 11, M = 20", current: uint64(10), chain: createMockChain(11, 20), result: Result{futureStart: 11, futureEnd: 20, futureLength: 10}},
   207  		{name: "X = 10, N = 13, M = 20", current: uint64(10), chain: createMockChain(13, 20), result: Result{futureStart: 13, futureEnd: 20, futureLength: 8}},
   208  		{name: "X = 10, N = 2, M = 10", current: uint64(10), chain: createMockChain(2, 10), result: Result{pastStart: 2, pastEnd: 10, pastLength: 9}},
   209  		{name: "X = 10, N = 2, M = 9", current: uint64(10), chain: createMockChain(2, 9), result: Result{pastStart: 2, pastEnd: 9, pastLength: 8}},
   210  		{name: "X = 10, N = 2, M = 8", current: uint64(10), chain: createMockChain(2, 8), result: Result{pastStart: 2, pastEnd: 8, pastLength: 7}},
   211  		{name: "X = 10, N = 5, M = 15", current: uint64(10), chain: createMockChain(5, 15), result: Result{pastStart: 5, pastEnd: 10, pastLength: 6, futureStart: 11, futureEnd: 15, futureLength: 5}},
   212  		{name: "X = 10, N = 10, M = 20", current: uint64(10), chain: createMockChain(10, 20), result: Result{pastStart: 10, pastEnd: 10, pastLength: 1, futureStart: 11, futureEnd: 20, futureLength: 10}},
   213  	}
   214  	for _, tc := range testCases {
   215  		tc := tc
   216  		t.Run(tc.name, func(t *testing.T) {
   217  			t.Parallel()
   218  			past, future := splitChain(tc.current, tc.chain)
   219  			require.Equal(t, len(past), tc.result.pastLength)
   220  			require.Equal(t, len(future), tc.result.futureLength)
   221  
   222  			if len(past) > 0 {
   223  				// Check if we have expected block/s
   224  				require.Equal(t, past[0].Number.Uint64(), tc.result.pastStart)
   225  				require.Equal(t, past[len(past)-1].Number.Uint64(), tc.result.pastEnd)
   226  			}
   227  
   228  			if len(future) > 0 {
   229  				// Check if we have expected block/s
   230  				require.Equal(t, future[0].Number.Uint64(), tc.result.futureStart)
   231  				require.Equal(t, future[len(future)-1].Number.Uint64(), tc.result.futureEnd)
   232  			}
   233  		})
   234  	}
   235  }
   236  
   237  //nolint:gocognit
   238  func TestSplitChainProperties(t *testing.T) {
   239  	t.Parallel()
   240  
   241  	// Current chain is at block: X
   242  	// Incoming chain is represented as [N, M]
   243  
   244  	currentChain := []int{0, 1, 2, 3, 10, 100} // blocks starting from genesis
   245  	blockDiffs := []int{0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 90, 100, 101, 102}
   246  
   247  	caseParams := make(map[int]map[int]map[int]struct{}) // X -> N -> M
   248  
   249  	for _, current := range currentChain {
   250  		// past cases only + past to current
   251  		for _, diff := range blockDiffs {
   252  			from := current - diff
   253  
   254  			// use int type for everything to not care about underflow
   255  			if from < 0 {
   256  				continue
   257  			}
   258  
   259  			for _, diff := range blockDiffs {
   260  				to := current - diff
   261  
   262  				if to >= from {
   263  					addTestCaseParams(caseParams, current, from, to)
   264  				}
   265  			}
   266  		}
   267  
   268  		// future only + current to future
   269  		for _, diff := range blockDiffs {
   270  			from := current + diff
   271  
   272  			if from < 0 {
   273  				continue
   274  			}
   275  
   276  			for _, diff := range blockDiffs {
   277  				to := current + diff
   278  
   279  				if to >= from {
   280  					addTestCaseParams(caseParams, current, from, to)
   281  				}
   282  			}
   283  		}
   284  
   285  		// past-current-future
   286  		for _, diff := range blockDiffs {
   287  			from := current - diff
   288  
   289  			if from < 0 {
   290  				continue
   291  			}
   292  
   293  			for _, diff := range blockDiffs {
   294  				to := current + diff
   295  
   296  				if to >= from {
   297  					addTestCaseParams(caseParams, current, from, to)
   298  				}
   299  			}
   300  		}
   301  	}
   302  
   303  	type testCase struct {
   304  		current     int
   305  		remoteStart int
   306  		remoteEnd   int
   307  	}
   308  
   309  	var ts []testCase
   310  
   311  	// X -> N -> M
   312  	for x, nm := range caseParams {
   313  		for n, mMap := range nm {
   314  			for m := range mMap {
   315  				ts = append(ts, testCase{x, n, m})
   316  			}
   317  		}
   318  	}
   319  
   320  	//nolint:paralleltest
   321  	for i, tc := range ts {
   322  		tc := tc
   323  
   324  		name := fmt.Sprintf("test case: index = %d, X = %d, N = %d, M = %d", i, tc.current, tc.remoteStart, tc.remoteEnd)
   325  
   326  		t.Run(name, func(t *testing.T) {
   327  			t.Parallel()
   328  
   329  			chain := createMockChain(uint64(tc.remoteStart), uint64(tc.remoteEnd))
   330  
   331  			past, future := splitChain(uint64(tc.current), chain)
   332  
   333  			// properties
   334  			if len(past) > 0 {
   335  				// Check if the chain is ordered
   336  				isOrdered := sort.SliceIsSorted(past, func(i, j int) bool {
   337  					return past[i].Number.Uint64() < past[j].Number.Uint64()
   338  				})
   339  
   340  				require.True(t, isOrdered, "an ordered past chain expected: %v", past)
   341  
   342  				isSequential := sort.SliceIsSorted(past, func(i, j int) bool {
   343  					return past[i].Number.Uint64() == past[j].Number.Uint64()-1
   344  				})
   345  
   346  				require.True(t, isSequential, "a sequential past chain expected: %v", past)
   347  
   348  				// Check if current block >= past chain's last block
   349  				require.Equal(t, past[len(past)-1].Number.Uint64() <= uint64(tc.current), true)
   350  			}
   351  
   352  			if len(future) > 0 {
   353  				// Check if the chain is ordered
   354  				isOrdered := sort.SliceIsSorted(future, func(i, j int) bool {
   355  					return future[i].Number.Uint64() < future[j].Number.Uint64()
   356  				})
   357  
   358  				require.True(t, isOrdered, "an ordered future chain expected: %v", future)
   359  
   360  				isSequential := sort.SliceIsSorted(future, func(i, j int) bool {
   361  					return future[i].Number.Uint64() == future[j].Number.Uint64()-1
   362  				})
   363  
   364  				require.True(t, isSequential, "a sequential future chain expected: %v", future)
   365  
   366  				// Check if future chain's first block > current block
   367  				require.Equal(t, future[len(future)-1].Number.Uint64() > uint64(tc.current), true)
   368  			}
   369  
   370  			// Check if both chains are continuous
   371  			if len(past) > 0 && len(future) > 0 {
   372  				require.Equal(t, past[len(past)-1].Number.Uint64(), future[0].Number.Uint64()-1)
   373  			}
   374  
   375  			// Check if we get the original chain on appending both
   376  			gotChain := append(past, future...)
   377  			require.Equal(t, reflect.DeepEqual(gotChain, chain), true)
   378  		})
   379  	}
   380  }
   381  
   382  // createMockChain returns a chain with dummy headers
   383  // starting from `start` to `end` (inclusive)
   384  func createMockChain(start, end uint64) []*types.Header {
   385  	var (
   386  		i     uint64
   387  		idx   uint64
   388  		chain []*types.Header = make([]*types.Header, end-start+1)
   389  	)
   390  
   391  	for i = start; i <= end; i++ {
   392  		header := &types.Header{
   393  			Number: big.NewInt(int64(i)),
   394  			Time:   uint64(time.Now().UnixMicro()) + i,
   395  		}
   396  		chain[idx] = header
   397  		idx++
   398  	}
   399  
   400  	return chain
   401  }
   402  
   403  // mXNM should be initialized
   404  func addTestCaseParams(mXNM map[int]map[int]map[int]struct{}, x, n, m int) {
   405  	//nolint:ineffassign
   406  	mNM, ok := mXNM[x]
   407  	if !ok {
   408  		mNM = make(map[int]map[int]struct{})
   409  		mXNM[x] = mNM
   410  	}
   411  
   412  	//nolint:ineffassign
   413  	_, ok = mNM[n]
   414  	if !ok {
   415  		mM := make(map[int]struct{})
   416  		mNM[n] = mM
   417  	}
   418  
   419  	mXNM[x][n][m] = struct{}{}
   420  }