github.com/ethereum/go-ethereum@v1.16.1/beacon/light/sync/test_helpers.go (about) 1 // Copyright 2023 The go-ethereum Authors 2 // This file is part of the go-ethereum library. 3 // 4 // The go-ethereum library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The go-ethereum library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 16 17 package sync 18 19 import ( 20 "reflect" 21 "testing" 22 23 "github.com/ethereum/go-ethereum/beacon/light" 24 "github.com/ethereum/go-ethereum/beacon/light/request" 25 "github.com/ethereum/go-ethereum/beacon/types" 26 ) 27 28 type requestWithID struct { 29 sid request.ServerAndID 30 request request.Request 31 } 32 33 type TestScheduler struct { 34 t *testing.T 35 module request.Module 36 events []request.Event 37 servers []request.Server 38 allowance map[request.Server]int 39 sent map[int][]requestWithID 40 testIndex int 41 expFail map[request.Server]int // expected Server.Fail calls during next Run 42 lastId request.ID 43 } 44 45 func NewTestScheduler(t *testing.T, module request.Module) *TestScheduler { 46 return &TestScheduler{ 47 t: t, 48 module: module, 49 allowance: make(map[request.Server]int), 50 expFail: make(map[request.Server]int), 51 sent: make(map[int][]requestWithID), 52 } 53 } 54 55 func (ts *TestScheduler) Run(testIndex int, exp ...any) { 56 expReqs := make([]requestWithID, len(exp)/2) 57 id := ts.lastId 58 for i := range expReqs { 59 id++ 60 expReqs[i] = requestWithID{ 61 sid: request.ServerAndID{Server: exp[i*2].(request.Server), ID: id}, 62 request: exp[i*2+1].(request.Request), 63 } 64 } 65 if len(expReqs) == 0 { 66 expReqs = nil 67 } 68 69 ts.testIndex = testIndex 70 ts.module.Process(ts, ts.events) 71 ts.events = nil 72 73 for server, count := range ts.expFail { 74 delete(ts.expFail, server) 75 if count == 0 { 76 continue 77 } 78 ts.t.Errorf("Missing %d Server.Fail(s) from server %s in test case #%d", count, server.Name(), testIndex) 79 } 80 81 if !reflect.DeepEqual(ts.sent[testIndex], expReqs) { 82 ts.t.Errorf("Wrong sent requests in test case #%d (expected %v, got %v)", testIndex, expReqs, ts.sent[testIndex]) 83 } 84 } 85 86 func (ts *TestScheduler) CanSendTo() (cs []request.Server) { 87 for _, server := range ts.servers { 88 if ts.allowance[server] > 0 { 89 cs = append(cs, server) 90 } 91 } 92 return 93 } 94 95 func (ts *TestScheduler) Send(server request.Server, req request.Request) request.ID { 96 ts.lastId++ 97 ts.sent[ts.testIndex] = append(ts.sent[ts.testIndex], requestWithID{ 98 sid: request.ServerAndID{Server: server, ID: ts.lastId}, 99 request: req, 100 }) 101 ts.allowance[server]-- 102 return ts.lastId 103 } 104 105 func (ts *TestScheduler) Fail(server request.Server, desc string) { 106 if ts.expFail[server] == 0 { 107 ts.t.Errorf("Unexpected Fail from server %s in test case #%d: %s", server.Name(), ts.testIndex, desc) 108 return 109 } 110 ts.expFail[server]-- 111 } 112 113 func (ts *TestScheduler) Request(testIndex, reqIndex int) requestWithID { 114 if len(ts.sent[testIndex]) < reqIndex { 115 ts.t.Errorf("Missing request from test case %d index %d", testIndex, reqIndex) 116 return requestWithID{} 117 } 118 return ts.sent[testIndex][reqIndex-1] 119 } 120 121 func (ts *TestScheduler) ServerEvent(evType *request.EventType, server request.Server, data any) { 122 ts.events = append(ts.events, request.Event{ 123 Type: evType, 124 Server: server, 125 Data: data, 126 }) 127 } 128 129 func (ts *TestScheduler) RequestEvent(evType *request.EventType, req requestWithID, resp request.Response) { 130 if req.request == nil { 131 return 132 } 133 ts.events = append(ts.events, request.Event{ 134 Type: evType, 135 Server: req.sid.Server, 136 Data: request.RequestResponse{ 137 ID: req.sid.ID, 138 Request: req.request, 139 Response: resp, 140 }, 141 }) 142 } 143 144 func (ts *TestScheduler) AddServer(server request.Server, allowance int) { 145 ts.servers = append(ts.servers, server) 146 ts.allowance[server] = allowance 147 ts.ServerEvent(request.EvRegistered, server, nil) 148 } 149 150 func (ts *TestScheduler) RemoveServer(server request.Server) { 151 ts.servers = append(ts.servers, server) 152 for i, s := range ts.servers { 153 if s == server { 154 copy(ts.servers[i:len(ts.servers)-1], ts.servers[i+1:]) 155 ts.servers = ts.servers[:len(ts.servers)-1] 156 break 157 } 158 } 159 delete(ts.allowance, server) 160 ts.ServerEvent(request.EvUnregistered, server, nil) 161 } 162 163 func (ts *TestScheduler) AddAllowance(server request.Server, allowance int) { 164 ts.allowance[server] += allowance 165 } 166 167 func (ts *TestScheduler) ExpFail(server request.Server) { 168 ts.expFail[server]++ 169 } 170 171 type TestCommitteeChain struct { 172 fsp, nsp uint64 173 init bool 174 } 175 176 func (tc *TestCommitteeChain) CheckpointInit(bootstrap types.BootstrapData) error { 177 tc.fsp, tc.nsp, tc.init = bootstrap.Header.SyncPeriod(), bootstrap.Header.SyncPeriod()+2, true 178 return nil 179 } 180 181 func (tc *TestCommitteeChain) InsertUpdate(update *types.LightClientUpdate, nextCommittee *types.SerializedSyncCommittee) error { 182 period := update.AttestedHeader.Header.SyncPeriod() 183 if period < tc.fsp || period > tc.nsp || !tc.init { 184 return light.ErrInvalidPeriod 185 } 186 if period == tc.nsp { 187 tc.nsp++ 188 } 189 return nil 190 } 191 192 func (tc *TestCommitteeChain) NextSyncPeriod() (uint64, bool) { 193 return tc.nsp, tc.init 194 } 195 196 func (tc *TestCommitteeChain) ExpInit(t *testing.T, ExpInit bool) { 197 if tc.init != ExpInit { 198 t.Errorf("Incorrect init flag (expected %v, got %v)", ExpInit, tc.init) 199 } 200 } 201 202 func (tc *TestCommitteeChain) SetNextSyncPeriod(nsp uint64) { 203 tc.init, tc.nsp = true, nsp 204 } 205 206 func (tc *TestCommitteeChain) ExpNextSyncPeriod(t *testing.T, expNsp uint64) { 207 tc.ExpInit(t, true) 208 if tc.nsp != expNsp { 209 t.Errorf("Incorrect NextSyncPeriod (expected %d, got %d)", expNsp, tc.nsp) 210 } 211 } 212 213 type TestHeadTracker struct { 214 phead types.HeadInfo 215 validated []types.OptimisticUpdate 216 finality types.FinalityUpdate 217 } 218 219 func (ht *TestHeadTracker) ValidateOptimistic(update types.OptimisticUpdate) (bool, error) { 220 ht.validated = append(ht.validated, update) 221 return true, nil 222 } 223 224 func (ht *TestHeadTracker) ValidateFinality(update types.FinalityUpdate) (bool, error) { 225 ht.finality = update 226 return true, nil 227 } 228 229 func (ht *TestHeadTracker) ValidatedFinality() (types.FinalityUpdate, bool) { 230 return ht.finality, ht.finality.Attested.Header != (types.Header{}) 231 } 232 233 func (ht *TestHeadTracker) ExpValidated(t *testing.T, tci int, expHeads []types.OptimisticUpdate) { 234 for i, expHead := range expHeads { 235 if i >= len(ht.validated) { 236 t.Errorf("Missing validated head in test case #%d index #%d (expected {slot %d blockRoot %x}, got none)", tci, i, expHead.Attested.Header.Slot, expHead.Attested.Header.Hash()) 237 continue 238 } 239 if !reflect.DeepEqual(ht.validated[i], expHead) { 240 vhead := ht.validated[i].Attested.Header 241 t.Errorf("Wrong validated head in test case #%d index #%d (expected {slot %d blockRoot %x}, got {slot %d blockRoot %x})", tci, i, expHead.Attested.Header.Slot, expHead.Attested.Header.Hash(), vhead.Slot, vhead.Hash()) 242 } 243 } 244 for i := len(expHeads); i < len(ht.validated); i++ { 245 vhead := ht.validated[i].Attested.Header 246 t.Errorf("Unexpected validated head in test case #%d index #%d (expected none, got {slot %d blockRoot %x})", tci, i, vhead.Slot, vhead.Hash()) 247 } 248 ht.validated = nil 249 } 250 251 func (ht *TestHeadTracker) SetPrefetchHead(head types.HeadInfo) { 252 ht.phead = head 253 } 254 255 func (ht *TestHeadTracker) ExpPrefetch(t *testing.T, tci int, exp types.HeadInfo) { 256 if ht.phead != exp { 257 t.Errorf("Wrong prefetch head in test case #%d (expected {slot %d blockRoot %x}, got {slot %d blockRoot %x})", tci, exp.Slot, exp.BlockRoot, ht.phead.Slot, ht.phead.BlockRoot) 258 } 259 }