github.com/VietJr/bor@v1.0.3/core/forkchoice_test.go (about) 1 package core 2 3 import ( 4 "math/big" 5 "testing" 6 7 "github.com/ethereum/go-ethereum/common" 8 "github.com/ethereum/go-ethereum/consensus/ethash" 9 "github.com/ethereum/go-ethereum/core/rawdb" 10 "github.com/ethereum/go-ethereum/core/types" 11 "github.com/ethereum/go-ethereum/params" 12 ) 13 14 // chainValidatorFake is a mock for the chain validator service 15 type chainValidatorFake struct { 16 validate func(currentHeader *types.Header, chain []*types.Header) bool 17 } 18 19 // chainReaderFake is a mock for the chain reader service 20 type chainReaderFake struct { 21 getTd func(hash common.Hash, number uint64) *big.Int 22 } 23 24 func newChainValidatorFake(validate func(currentHeader *types.Header, chain []*types.Header) bool) *chainValidatorFake { 25 return &chainValidatorFake{validate: validate} 26 } 27 28 func newChainReaderFake(getTd func(hash common.Hash, number uint64) *big.Int) *chainReaderFake { 29 return &chainReaderFake{getTd: getTd} 30 } 31 32 func TestPastChainInsert(t *testing.T) { 33 t.Parallel() 34 35 var ( 36 db = rawdb.NewMemoryDatabase() 37 genesis = (&Genesis{BaseFee: big.NewInt(params.InitialBaseFee)}).MustCommit(db) 38 ) 39 40 hc, err := NewHeaderChain(db, params.AllEthashProtocolChanges, ethash.NewFaker(), func() bool { return false }) 41 if err != nil { 42 t.Fatal(err) 43 } 44 45 // Create mocks for forker 46 getTd := func(hash common.Hash, number uint64) *big.Int { 47 return big.NewInt(int64(number)) 48 } 49 validate := func(currentHeader *types.Header, chain []*types.Header) bool { 50 // Put all explicit conditions here 51 // If canonical chain is empty and we're importing a chain of 64 blocks 52 if currentHeader.Number.Uint64() == uint64(0) && len(chain) == 64 { 53 return true 54 } 55 // If canonical chain is of len 64 and we're importing a past chain from 54-64, then accept it 56 if currentHeader.Number.Uint64() == uint64(64) && chain[0].Number.Uint64() == 55 && len(chain) == 10 { 57 return true 58 } 59 60 return false 61 } 62 mockChainReader := newChainReaderFake(getTd) 63 mockChainValidator := newChainValidatorFake(validate) 64 mockForker := NewForkChoice(mockChainReader, nil, mockChainValidator) 65 66 // chain A: G->A1->A2...A64 67 chainA := makeHeaderChain(genesis.Header(), 64, ethash.NewFaker(), db, 10) 68 69 // Inserting 64 headers on an empty chain 70 // expecting 1 write status with no error 71 testInsert(t, hc, chainA, CanonStatTy, nil, mockForker) 72 73 // The current chain is: G->A1->A2...A64 74 // chain B: G->A1->A2...A44->B45->B46...B64 75 chainB := makeHeaderChain(chainA[43], 20, ethash.NewFaker(), db, 10) 76 77 // The current chain is: G->A1->A2...A64 78 // chain C: G->A1->A2...A54->C55->C56...C64 79 chainC := makeHeaderChain(chainA[53], 10, ethash.NewFaker(), db, 10) 80 81 // Update the function to consider chainC with higher difficulty 82 getTd = func(hash common.Hash, number uint64) *big.Int { 83 td := big.NewInt(int64(number)) 84 if hash == chainB[len(chainB)-1].Hash() || hash == chainC[len(chainC)-1].Hash() { 85 td = big.NewInt(65) 86 } 87 88 return td 89 } 90 mockChainReader = newChainReaderFake(getTd) 91 mockForker = NewForkChoice(mockChainReader, nil, mockChainValidator) 92 93 // Inserting 20 blocks from chainC on canonical chain 94 // expecting 2 write status with no error 95 testInsert(t, hc, chainB, SideStatTy, nil, mockForker) 96 97 // Inserting 10 blocks from chainB on canonical chain 98 // expecting 1 write status with no error 99 testInsert(t, hc, chainC, CanonStatTy, nil, mockForker) 100 } 101 102 func TestFutureChainInsert(t *testing.T) { 103 t.Parallel() 104 105 var ( 106 db = rawdb.NewMemoryDatabase() 107 genesis = (&Genesis{BaseFee: big.NewInt(params.InitialBaseFee)}).MustCommit(db) 108 ) 109 110 hc, err := NewHeaderChain(db, params.AllEthashProtocolChanges, ethash.NewFaker(), func() bool { return false }) 111 if err != nil { 112 t.Fatal(err) 113 } 114 115 // Create mocks for forker 116 getTd := func(hash common.Hash, number uint64) *big.Int { 117 return big.NewInt(int64(number)) 118 } 119 validate := func(currentHeader *types.Header, chain []*types.Header) bool { 120 // Put all explicit conditions here 121 // If canonical chain is empty and we're importing a chain of 64 blocks 122 if currentHeader.Number.Uint64() == uint64(0) && len(chain) == 64 { 123 return true 124 } 125 // If length of future chains > some value, they should not be accepted 126 if currentHeader.Number.Uint64() == uint64(64) && len(chain) <= 10 { 127 return true 128 } 129 130 return false 131 } 132 mockChainReader := newChainReaderFake(getTd) 133 mockChainValidator := newChainValidatorFake(validate) 134 mockForker := NewForkChoice(mockChainReader, nil, mockChainValidator) 135 136 // chain A: G->A1->A2...A64 137 chainA := makeHeaderChain(genesis.Header(), 64, ethash.NewFaker(), db, 10) 138 139 // Inserting 64 headers on an empty chain 140 // expecting 1 write status with no error 141 testInsert(t, hc, chainA, CanonStatTy, nil, mockForker) 142 143 // The current chain is: G->A1->A2...A64 144 // chain B: G->A1->A2...A64->B65->B66...B84 145 chainB := makeHeaderChain(chainA[63], 20, ethash.NewFaker(), db, 10) 146 147 // Inserting 20 headers on the canonical chain 148 // expecting 0 write status with no error 149 testInsert(t, hc, chainB, SideStatTy, nil, mockForker) 150 151 // The current chain is: G->A1->A2...A64 152 // chain C: G->A1->A2...A64->C65->C66...C74 153 chainC := makeHeaderChain(chainA[63], 10, ethash.NewFaker(), db, 10) 154 155 // Inserting 10 headers on the canonical chain 156 // expecting 0 write status with no error 157 testInsert(t, hc, chainC, CanonStatTy, nil, mockForker) 158 } 159 160 func TestOverlappingChainInsert(t *testing.T) { 161 t.Parallel() 162 163 var ( 164 db = rawdb.NewMemoryDatabase() 165 genesis = (&Genesis{BaseFee: big.NewInt(params.InitialBaseFee)}).MustCommit(db) 166 ) 167 168 hc, err := NewHeaderChain(db, params.AllEthashProtocolChanges, ethash.NewFaker(), func() bool { return false }) 169 if err != nil { 170 t.Fatal(err) 171 } 172 173 // Create mocks for forker 174 getTd := func(hash common.Hash, number uint64) *big.Int { 175 return big.NewInt(int64(number)) 176 } 177 validate := func(currentHeader *types.Header, chain []*types.Header) bool { 178 // Put all explicit conditions here 179 // If canonical chain is empty and we're importing a chain of 64 blocks 180 if currentHeader.Number.Uint64() == uint64(0) && len(chain) == 64 { 181 return true 182 } 183 // If length of chain is > some fixed value then don't accept it 184 if currentHeader.Number.Uint64() == uint64(64) && len(chain) <= 20 { 185 return true 186 } 187 188 return false 189 } 190 mockChainReader := newChainReaderFake(getTd) 191 mockChainValidator := newChainValidatorFake(validate) 192 mockForker := NewForkChoice(mockChainReader, nil, mockChainValidator) 193 194 // chain A: G->A1->A2...A64 195 chainA := makeHeaderChain(genesis.Header(), 64, ethash.NewFaker(), db, 10) 196 197 // Inserting 64 headers on an empty chain 198 // expecting 1 write status with no error 199 testInsert(t, hc, chainA, CanonStatTy, nil, mockForker) 200 201 // The current chain is: G->A1->A2...A64 202 // chain B: G->A1->A2...A54->B55->B56...B84 203 chainB := makeHeaderChain(chainA[53], 30, ethash.NewFaker(), db, 10) 204 205 // Inserting 20 blocks on canonical chain 206 // expecting 2 write status with no error 207 testInsert(t, hc, chainB, SideStatTy, nil, mockForker) 208 209 // The current chain is: G->A1->A2...A64 210 // chain C: G->A1->A2...A54->C55->C56...C74 211 chainC := makeHeaderChain(chainA[53], 20, ethash.NewFaker(), db, 10) 212 213 // Inserting 10 blocks on canonical chain 214 // expecting 1 write status with no error 215 testInsert(t, hc, chainC, CanonStatTy, nil, mockForker) 216 } 217 218 // Mock chain reader functions 219 func (c *chainReaderFake) Config() *params.ChainConfig { 220 return ¶ms.ChainConfig{TerminalTotalDifficulty: nil} 221 } 222 func (c *chainReaderFake) GetTd(hash common.Hash, number uint64) *big.Int { 223 return c.getTd(hash, number) 224 } 225 226 // Mock chain validator functions 227 func (w *chainValidatorFake) IsValidPeer(remoteHeader *types.Header, fetchHeadersByNumber func(number uint64, amount int, skip int, reverse bool) ([]*types.Header, []common.Hash, error)) (bool, error) { 228 return true, nil 229 } 230 func (w *chainValidatorFake) IsValidChain(current *types.Header, headers []*types.Header) bool { 231 return w.validate(current, headers) 232 } 233 func (w *chainValidatorFake) ProcessCheckpoint(endBlockNum uint64, endBlockHash common.Hash) {} 234 func (w *chainValidatorFake) GetCheckpointWhitelist() map[uint64]common.Hash { 235 return nil 236 } 237 func (w *chainValidatorFake) PurgeCheckpointWhitelist() {} 238 func (w *chainValidatorFake) GetCheckpoints(current, sidechainHeader *types.Header, sidechainCheckpoints []*types.Header) (map[uint64]*types.Header, error) { 239 return map[uint64]*types.Header{}, nil 240 }