github.com/palisadeinc/bor@v0.0.0-20230615125219-ab7196213d15/eth/downloader/whitelist/service_test.go (about) 1 package whitelist 2 3 import ( 4 "errors" 5 "fmt" 6 "math/big" 7 "reflect" 8 "sort" 9 "testing" 10 "time" 11 12 "github.com/stretchr/testify/require" 13 14 "github.com/ethereum/go-ethereum/common" 15 "github.com/ethereum/go-ethereum/core/types" 16 ) 17 18 // NewMockService creates a new mock whitelist service 19 func NewMockService(maxCapacity uint, checkpointInterval uint64) *Service { 20 return &Service{ 21 checkpointWhitelist: make(map[uint64]common.Hash), 22 checkpointOrder: []uint64{}, 23 maxCapacity: maxCapacity, 24 checkpointInterval: checkpointInterval, 25 } 26 } 27 28 // TestWhitelistCheckpoint checks the checkpoint whitelist map queue mechanism 29 func TestWhitelistCheckpoint(t *testing.T) { 30 t.Parallel() 31 32 s := NewMockService(10, 10) 33 for i := 0; i < 10; i++ { 34 s.enqueueCheckpointWhitelist(uint64(i), common.Hash{}) 35 } 36 require.Equal(t, s.length(), 10, "expected 10 items in whitelist") 37 38 s.enqueueCheckpointWhitelist(11, common.Hash{}) 39 s.dequeueCheckpointWhitelist() 40 require.Equal(t, s.length(), 10, "expected 10 items in whitelist") 41 } 42 43 // TestIsValidPeer checks the IsValidPeer function in isolation 44 // for different cases by providing a mock fetchHeadersByNumber function 45 func TestIsValidPeer(t *testing.T) { 46 t.Parallel() 47 48 s := NewMockService(10, 10) 49 50 // case1: no checkpoint whitelist, should consider the chain as valid 51 res, err := s.IsValidPeer(nil, nil) 52 require.NoError(t, err, "expected no error") 53 require.Equal(t, res, true, "expected chain to be valid") 54 55 // add checkpoint entries and mock fetchHeadersByNumber function 56 s.ProcessCheckpoint(uint64(0), common.Hash{}) 57 s.ProcessCheckpoint(uint64(1), common.Hash{}) 58 59 require.Equal(t, s.length(), 2, "expected 2 items in whitelist") 60 61 // create a false function, returning absolutely nothing 62 falseFetchHeadersByNumber := func(number uint64, amount int, skip int, reverse bool) ([]*types.Header, []common.Hash, error) { 63 return nil, nil, nil 64 } 65 66 // case2: false fetchHeadersByNumber function provided, should consider the chain as invalid 67 // and throw `ErrNoRemoteCheckpoint` error 68 res, err = s.IsValidPeer(nil, falseFetchHeadersByNumber) 69 if err == nil { 70 t.Fatal("expected error, got nil") 71 } 72 73 if !errors.Is(err, ErrNoRemoteCheckpoint) { 74 t.Fatalf("expected error ErrNoRemoteCheckpoint, got %v", err) 75 } 76 77 require.Equal(t, res, false, "expected chain to be invalid") 78 79 // case3: correct fetchHeadersByNumber function provided, should consider the chain as valid 80 // create a mock function, returning a the required header 81 fetchHeadersByNumber := func(number uint64, _ int, _ int, _ bool) ([]*types.Header, []common.Hash, error) { 82 hash := common.Hash{} 83 header := types.Header{Number: big.NewInt(0)} 84 85 switch number { 86 case 0: 87 return []*types.Header{&header}, []common.Hash{hash}, nil 88 case 1: 89 header.Number = big.NewInt(1) 90 return []*types.Header{&header}, []common.Hash{hash}, nil 91 case 2: 92 header.Number = big.NewInt(1) // sending wrong header for misamatch 93 return []*types.Header{&header}, []common.Hash{hash}, nil 94 default: 95 return nil, nil, errors.New("invalid number") 96 } 97 } 98 99 res, err = s.IsValidPeer(nil, fetchHeadersByNumber) 100 require.NoError(t, err, "expected no error") 101 require.Equal(t, res, true, "expected chain to be valid") 102 103 // add one more checkpoint whitelist entry 104 s.ProcessCheckpoint(uint64(2), common.Hash{}) 105 require.Equal(t, s.length(), 3, "expected 3 items in whitelist") 106 107 // case4: correct fetchHeadersByNumber function provided with wrong header 108 // for block number 2. Should consider the chain as invalid and throw an error 109 res, err = s.IsValidPeer(nil, fetchHeadersByNumber) 110 require.Equal(t, err, ErrCheckpointMismatch, "expected checkpoint mismatch error") 111 require.Equal(t, res, false, "expected chain to be invalid") 112 } 113 114 // TestIsValidChain checks the IsValidChain function in isolation 115 // for different cases by providing a mock current header and chain 116 func TestIsValidChain(t *testing.T) { 117 t.Parallel() 118 119 s := NewMockService(10, 10) 120 chainA := createMockChain(1, 20) // A1->A2...A19->A20 121 // case1: no checkpoint whitelist, should consider the chain as valid 122 res, err := s.IsValidChain(nil, chainA) 123 require.Equal(t, res, true, "expected chain to be valid") 124 require.Equal(t, err, nil, "expected error to be nil") 125 126 tempChain := createMockChain(21, 22) // A21->A22 127 128 // add mock checkpoint entries 129 s.ProcessCheckpoint(tempChain[0].Number.Uint64(), tempChain[0].Hash()) 130 s.ProcessCheckpoint(tempChain[1].Number.Uint64(), tempChain[1].Hash()) 131 132 require.Equal(t, s.length(), 2, "expected 2 items in whitelist") 133 134 // case2: We're behind the oldest whitelisted block entry, should consider 135 // the chain as valid as we're still far behind the latest blocks 136 res, err = s.IsValidChain(chainA[len(chainA)-1], chainA) 137 require.Equal(t, res, true, "expected chain to be valid") 138 require.Equal(t, err, nil, "expected error to be nil") 139 140 // Clear checkpoint whitelist and add blocks A5 and A15 in whitelist 141 s.PurgeCheckpointWhitelist() 142 s.ProcessCheckpoint(chainA[5].Number.Uint64(), chainA[5].Hash()) 143 s.ProcessCheckpoint(chainA[15].Number.Uint64(), chainA[15].Hash()) 144 145 require.Equal(t, s.length(), 2, "expected 2 items in whitelist") 146 147 // case3: Try importing a past chain having valid checkpoint, should 148 // consider the chain as valid 149 res, err = s.IsValidChain(chainA[len(chainA)-1], chainA) 150 require.Equal(t, res, true, "expected chain to be valid") 151 require.Equal(t, err, nil, "expected error to be nil") 152 153 // Clear checkpoint whitelist and mock blocks in whitelist 154 tempChain = createMockChain(20, 20) // A20 155 156 s.PurgeCheckpointWhitelist() 157 s.ProcessCheckpoint(tempChain[0].Number.Uint64(), tempChain[0].Hash()) 158 159 require.Equal(t, s.length(), 1, "expected 1 items in whitelist") 160 161 // case4: Try importing a past chain having invalid checkpoint 162 res, _ = s.IsValidChain(chainA[len(chainA)-1], chainA) 163 require.Equal(t, res, false, "expected chain to be invalid") 164 // Not checking error here because we return nil in case of checkpoint mismatch 165 166 // create a future chain to be imported of length <= `checkpointInterval` 167 chainB := createMockChain(21, 30) // B21->B22...B29->B30 168 169 // case5: Try importing a future chain (1) 170 res, err = s.IsValidChain(chainA[len(chainA)-1], chainB) 171 require.Equal(t, res, true, "expected chain to be valid") 172 require.Equal(t, err, nil, "expected error to be nil") 173 174 // create a future chain to be imported of length > `checkpointInterval` 175 chainB = createMockChain(21, 40) // C21->C22...C39->C40 176 177 // Note: Earlier, it used to reject future chains longer than some threshold. 178 // That check is removed for now. 179 180 // case6: Try importing a future chain (2) 181 res, err = s.IsValidChain(chainA[len(chainA)-1], chainB) 182 require.Equal(t, res, true, "expected chain to be valid") 183 require.Equal(t, err, nil, "expected error to be nil") 184 } 185 186 func TestSplitChain(t *testing.T) { 187 t.Parallel() 188 189 type Result struct { 190 pastStart uint64 191 pastEnd uint64 192 futureStart uint64 193 futureEnd uint64 194 pastLength int 195 futureLength int 196 } 197 198 // Current chain is at block: X 199 // Incoming chain is represented as [N, M] 200 testCases := []struct { 201 name string 202 current uint64 203 chain []*types.Header 204 result Result 205 }{ 206 {name: "X = 10, N = 11, M = 20", current: uint64(10), chain: createMockChain(11, 20), result: Result{futureStart: 11, futureEnd: 20, futureLength: 10}}, 207 {name: "X = 10, N = 13, M = 20", current: uint64(10), chain: createMockChain(13, 20), result: Result{futureStart: 13, futureEnd: 20, futureLength: 8}}, 208 {name: "X = 10, N = 2, M = 10", current: uint64(10), chain: createMockChain(2, 10), result: Result{pastStart: 2, pastEnd: 10, pastLength: 9}}, 209 {name: "X = 10, N = 2, M = 9", current: uint64(10), chain: createMockChain(2, 9), result: Result{pastStart: 2, pastEnd: 9, pastLength: 8}}, 210 {name: "X = 10, N = 2, M = 8", current: uint64(10), chain: createMockChain(2, 8), result: Result{pastStart: 2, pastEnd: 8, pastLength: 7}}, 211 {name: "X = 10, N = 5, M = 15", current: uint64(10), chain: createMockChain(5, 15), result: Result{pastStart: 5, pastEnd: 10, pastLength: 6, futureStart: 11, futureEnd: 15, futureLength: 5}}, 212 {name: "X = 10, N = 10, M = 20", current: uint64(10), chain: createMockChain(10, 20), result: Result{pastStart: 10, pastEnd: 10, pastLength: 1, futureStart: 11, futureEnd: 20, futureLength: 10}}, 213 } 214 for _, tc := range testCases { 215 tc := tc 216 t.Run(tc.name, func(t *testing.T) { 217 t.Parallel() 218 past, future := splitChain(tc.current, tc.chain) 219 require.Equal(t, len(past), tc.result.pastLength) 220 require.Equal(t, len(future), tc.result.futureLength) 221 222 if len(past) > 0 { 223 // Check if we have expected block/s 224 require.Equal(t, past[0].Number.Uint64(), tc.result.pastStart) 225 require.Equal(t, past[len(past)-1].Number.Uint64(), tc.result.pastEnd) 226 } 227 228 if len(future) > 0 { 229 // Check if we have expected block/s 230 require.Equal(t, future[0].Number.Uint64(), tc.result.futureStart) 231 require.Equal(t, future[len(future)-1].Number.Uint64(), tc.result.futureEnd) 232 } 233 }) 234 } 235 } 236 237 //nolint:gocognit 238 func TestSplitChainProperties(t *testing.T) { 239 t.Parallel() 240 241 // Current chain is at block: X 242 // Incoming chain is represented as [N, M] 243 244 currentChain := []int{0, 1, 2, 3, 10, 100} // blocks starting from genesis 245 blockDiffs := []int{0, 1, 2, 3, 4, 5, 9, 10, 11, 12, 90, 100, 101, 102} 246 247 caseParams := make(map[int]map[int]map[int]struct{}) // X -> N -> M 248 249 for _, current := range currentChain { 250 // past cases only + past to current 251 for _, diff := range blockDiffs { 252 from := current - diff 253 254 // use int type for everything to not care about underflow 255 if from < 0 { 256 continue 257 } 258 259 for _, diff := range blockDiffs { 260 to := current - diff 261 262 if to >= from { 263 addTestCaseParams(caseParams, current, from, to) 264 } 265 } 266 } 267 268 // future only + current to future 269 for _, diff := range blockDiffs { 270 from := current + diff 271 272 if from < 0 { 273 continue 274 } 275 276 for _, diff := range blockDiffs { 277 to := current + diff 278 279 if to >= from { 280 addTestCaseParams(caseParams, current, from, to) 281 } 282 } 283 } 284 285 // past-current-future 286 for _, diff := range blockDiffs { 287 from := current - diff 288 289 if from < 0 { 290 continue 291 } 292 293 for _, diff := range blockDiffs { 294 to := current + diff 295 296 if to >= from { 297 addTestCaseParams(caseParams, current, from, to) 298 } 299 } 300 } 301 } 302 303 type testCase struct { 304 current int 305 remoteStart int 306 remoteEnd int 307 } 308 309 var ts []testCase 310 311 // X -> N -> M 312 for x, nm := range caseParams { 313 for n, mMap := range nm { 314 for m := range mMap { 315 ts = append(ts, testCase{x, n, m}) 316 } 317 } 318 } 319 320 //nolint:paralleltest 321 for i, tc := range ts { 322 tc := tc 323 324 name := fmt.Sprintf("test case: index = %d, X = %d, N = %d, M = %d", i, tc.current, tc.remoteStart, tc.remoteEnd) 325 326 t.Run(name, func(t *testing.T) { 327 t.Parallel() 328 329 chain := createMockChain(uint64(tc.remoteStart), uint64(tc.remoteEnd)) 330 331 past, future := splitChain(uint64(tc.current), chain) 332 333 // properties 334 if len(past) > 0 { 335 // Check if the chain is ordered 336 isOrdered := sort.SliceIsSorted(past, func(i, j int) bool { 337 return past[i].Number.Uint64() < past[j].Number.Uint64() 338 }) 339 340 require.True(t, isOrdered, "an ordered past chain expected: %v", past) 341 342 isSequential := sort.SliceIsSorted(past, func(i, j int) bool { 343 return past[i].Number.Uint64() == past[j].Number.Uint64()-1 344 }) 345 346 require.True(t, isSequential, "a sequential past chain expected: %v", past) 347 348 // Check if current block >= past chain's last block 349 require.Equal(t, past[len(past)-1].Number.Uint64() <= uint64(tc.current), true) 350 } 351 352 if len(future) > 0 { 353 // Check if the chain is ordered 354 isOrdered := sort.SliceIsSorted(future, func(i, j int) bool { 355 return future[i].Number.Uint64() < future[j].Number.Uint64() 356 }) 357 358 require.True(t, isOrdered, "an ordered future chain expected: %v", future) 359 360 isSequential := sort.SliceIsSorted(future, func(i, j int) bool { 361 return future[i].Number.Uint64() == future[j].Number.Uint64()-1 362 }) 363 364 require.True(t, isSequential, "a sequential future chain expected: %v", future) 365 366 // Check if future chain's first block > current block 367 require.Equal(t, future[len(future)-1].Number.Uint64() > uint64(tc.current), true) 368 } 369 370 // Check if both chains are continuous 371 if len(past) > 0 && len(future) > 0 { 372 require.Equal(t, past[len(past)-1].Number.Uint64(), future[0].Number.Uint64()-1) 373 } 374 375 // Check if we get the original chain on appending both 376 gotChain := append(past, future...) 377 require.Equal(t, reflect.DeepEqual(gotChain, chain), true) 378 }) 379 } 380 } 381 382 // createMockChain returns a chain with dummy headers 383 // starting from `start` to `end` (inclusive) 384 func createMockChain(start, end uint64) []*types.Header { 385 var ( 386 i uint64 387 idx uint64 388 chain []*types.Header = make([]*types.Header, end-start+1) 389 ) 390 391 for i = start; i <= end; i++ { 392 header := &types.Header{ 393 Number: big.NewInt(int64(i)), 394 Time: uint64(time.Now().UnixMicro()) + i, 395 } 396 chain[idx] = header 397 idx++ 398 } 399 400 return chain 401 } 402 403 // mXNM should be initialized 404 func addTestCaseParams(mXNM map[int]map[int]map[int]struct{}, x, n, m int) { 405 //nolint:ineffassign 406 mNM, ok := mXNM[x] 407 if !ok { 408 mNM = make(map[int]map[int]struct{}) 409 mXNM[x] = mNM 410 } 411 412 //nolint:ineffassign 413 _, ok = mNM[n] 414 if !ok { 415 mM := make(map[int]struct{}) 416 mNM[n] = mM 417 } 418 419 mXNM[x][n][m] = struct{}{} 420 }