github.com/gnolang/gno@v0.0.0-20240520182011-228e9d0192ce/tm2/pkg/bft/consensus/common_test.go (about) 1 package consensus 2 3 import ( 4 "bytes" 5 "fmt" 6 "log/slog" 7 "os" 8 "path" 9 "path/filepath" 10 "reflect" 11 "sort" 12 "sync" 13 "testing" 14 "time" 15 16 abcicli "github.com/gnolang/gno/tm2/pkg/bft/abci/client" 17 "github.com/gnolang/gno/tm2/pkg/bft/abci/example/counter" 18 "github.com/gnolang/gno/tm2/pkg/bft/abci/example/kvstore" 19 abci "github.com/gnolang/gno/tm2/pkg/bft/abci/types" 20 cfg "github.com/gnolang/gno/tm2/pkg/bft/config" 21 cstypes "github.com/gnolang/gno/tm2/pkg/bft/consensus/types" 22 mempl "github.com/gnolang/gno/tm2/pkg/bft/mempool" 23 "github.com/gnolang/gno/tm2/pkg/bft/privval" 24 sm "github.com/gnolang/gno/tm2/pkg/bft/state" 25 "github.com/gnolang/gno/tm2/pkg/bft/store" 26 "github.com/gnolang/gno/tm2/pkg/bft/types" 27 tmtime "github.com/gnolang/gno/tm2/pkg/bft/types/time" 28 "github.com/gnolang/gno/tm2/pkg/crypto" 29 dbm "github.com/gnolang/gno/tm2/pkg/db" 30 "github.com/gnolang/gno/tm2/pkg/db/memdb" 31 "github.com/gnolang/gno/tm2/pkg/events" 32 "github.com/gnolang/gno/tm2/pkg/log" 33 osm "github.com/gnolang/gno/tm2/pkg/os" 34 ) 35 36 const ( 37 testSubscriber = "test-client" 38 ) 39 40 // A cleanupFunc cleans up any config / test files created for a particular 41 // test. 42 type cleanupFunc func() 43 44 // genesis, chain_id, priv_val 45 var config *cfg.Config // NOTE: must be reset for each _test.go file 46 var ( 47 consensusReplayConfig *cfg.Config 48 ensureTimeout = time.Millisecond * 20000 49 ) 50 51 func ensureDir(dir string, mode os.FileMode) { 52 if err := osm.EnsureDir(dir, mode); err != nil { 53 panic(err) 54 } 55 } 56 57 func ResetConfig(name string) (*cfg.Config, string) { 58 return cfg.ResetTestRoot(name) 59 } 60 61 // ------------------------------------------------------------------------------- 62 // validator stub (a kvstore consensus peer we control) 63 64 type validatorStub struct { 65 Index int // Validator index. NOTE: we don't assume validator set changes. 66 Height int64 67 Round int 68 types.PrivValidator 69 } 70 71 var testMinPower int64 = 10 72 73 func NewValidatorStub(privValidator types.PrivValidator, valIndex int) *validatorStub { 74 return &validatorStub{ 75 Index: valIndex, 76 PrivValidator: privValidator, 77 } 78 } 79 80 func (vs *validatorStub) signVote(voteType types.SignedMsgType, hash []byte, header types.PartSetHeader) (*types.Vote, error) { 81 addr := vs.PrivValidator.GetPubKey().Address() 82 vote := &types.Vote{ 83 ValidatorIndex: vs.Index, 84 ValidatorAddress: addr, 85 Height: vs.Height, 86 Round: vs.Round, 87 Timestamp: tmtime.Now(), 88 Type: voteType, 89 BlockID: types.BlockID{Hash: hash, PartsHeader: header}, 90 } 91 err := vs.PrivValidator.SignVote(config.ChainID(), vote) 92 return vote, err 93 } 94 95 // Sign vote for type/hash/header 96 func signVote(vs *validatorStub, voteType types.SignedMsgType, hash []byte, header types.PartSetHeader) *types.Vote { 97 v, err := vs.signVote(voteType, hash, header) 98 if err != nil { 99 panic(fmt.Errorf("failed to sign vote: %w", err)) 100 } 101 return v 102 } 103 104 func signVotes(voteType types.SignedMsgType, hash []byte, header types.PartSetHeader, vss ...*validatorStub) []*types.Vote { 105 votes := make([]*types.Vote, len(vss)) 106 for i, vs := range vss { 107 votes[i] = signVote(vs, voteType, hash, header) 108 } 109 return votes 110 } 111 112 func incrementHeight(vss ...*validatorStub) { 113 for _, vs := range vss { 114 vs.Height++ 115 } 116 } 117 118 func incrementRound(vss ...*validatorStub) { 119 for _, vs := range vss { 120 vs.Round++ 121 } 122 } 123 124 type ValidatorStubsByAddress []*validatorStub 125 126 func (vss ValidatorStubsByAddress) Len() int { 127 return len(vss) 128 } 129 130 func (vss ValidatorStubsByAddress) Less(i, j int) bool { 131 return vss[i].GetPubKey().Address().Compare(vss[j].GetPubKey().Address()) == -1 132 } 133 134 func (vss ValidatorStubsByAddress) Swap(i, j int) { 135 it := vss[i] 136 vss[i] = vss[j] 137 vss[i].Index = i 138 vss[j] = it 139 vss[j].Index = j 140 } 141 142 // ------------------------------------------------------------------------------- 143 // Functions for transitioning the consensus state 144 145 func startFrom(cs *ConsensusState, height int64, round int) { 146 go func() { 147 cs.enterNewRound(height, round) 148 cs.StartWithoutWALCatchup() 149 }() 150 } 151 152 // Create proposal block from cs but sign it with vs. 153 // NOTE: assumes cs already locked via mutex (perhaps via debugger). 154 func decideProposal(cs *ConsensusState, vs *validatorStub, height int64, round int) (proposal *types.Proposal, block *types.Block) { 155 block, blockParts := cs.createProposalBlock() 156 validRound := cs.ValidRound 157 chainID := cs.state.ChainID 158 if block == nil { 159 panic("Failed to createProposalBlock. Did you forget to add commit for previous block?") 160 } 161 162 // Make proposal 163 polRound, propBlockID := validRound, types.BlockID{Hash: block.Hash(), PartsHeader: blockParts.Header()} 164 proposal = types.NewProposal(height, round, polRound, propBlockID) 165 if err := vs.SignProposal(chainID, proposal); err != nil { 166 panic(err) 167 } 168 return 169 } 170 171 func addVotes(to *ConsensusState, votes ...*types.Vote) { 172 for _, vote := range votes { 173 to.peerMsgQueue <- msgInfo{Msg: &VoteMessage{vote}} 174 } 175 } 176 177 func signAddVotes(to *ConsensusState, voteType types.SignedMsgType, hash []byte, header types.PartSetHeader, vss ...*validatorStub) { 178 votes := signVotes(voteType, hash, header, vss...) 179 addVotes(to, votes...) 180 } 181 182 func validatePrevote(cs *ConsensusState, round int, privVal *validatorStub, blockHash []byte) { 183 prevotes := cs.Votes.Prevotes(round) 184 address := privVal.GetPubKey().Address() 185 var vote *types.Vote 186 if vote = prevotes.GetByAddress(address); vote == nil { 187 panic("Failed to find prevote from validator") 188 } 189 if blockHash == nil { 190 if vote.BlockID.Hash != nil { 191 panic(fmt.Sprintf("Expected prevote to be for nil, got %X", vote.BlockID.Hash)) 192 } 193 } else { 194 if !bytes.Equal(vote.BlockID.Hash, blockHash) { 195 panic(fmt.Sprintf("Expected prevote to be for %X, got %X", blockHash, vote.BlockID.Hash)) 196 } 197 } 198 } 199 200 func validateLastPrecommit(cs *ConsensusState, privVal *validatorStub, blockHash []byte) { 201 votes := cs.LastCommit 202 address := privVal.GetPubKey().Address() 203 var vote *types.Vote 204 if vote = votes.GetByAddress(address); vote == nil { 205 panic("Failed to find precommit from validator") 206 } 207 if !bytes.Equal(vote.BlockID.Hash, blockHash) { 208 panic(fmt.Sprintf("Expected precommit to be for %X, got %X", blockHash, vote.BlockID.Hash)) 209 } 210 } 211 212 func validatePrecommit(_ *testing.T, cs *ConsensusState, thisRound, lockRound int, privVal *validatorStub, votedBlockHash, lockedBlockHash []byte) { 213 precommits := cs.Votes.Precommits(thisRound) 214 address := privVal.GetPubKey().Address() 215 var vote *types.Vote 216 if vote = precommits.GetByAddress(address); vote == nil { 217 panic("Failed to find precommit from validator") 218 } 219 220 if votedBlockHash == nil { 221 if vote.BlockID.Hash != nil { 222 panic("Expected precommit to be for nil") 223 } 224 } else { 225 if !bytes.Equal(vote.BlockID.Hash, votedBlockHash) { 226 panic("Expected precommit to be for proposal block") 227 } 228 } 229 230 if lockedBlockHash == nil { 231 if cs.LockedRound != lockRound || cs.LockedBlock != nil { 232 panic(fmt.Sprintf("Expected to be locked on nil at round %d. Got locked at round %d with block %v", lockRound, cs.LockedRound, cs.LockedBlock)) 233 } 234 } else { 235 if cs.LockedRound != lockRound || !bytes.Equal(cs.LockedBlock.Hash(), lockedBlockHash) { 236 panic(fmt.Sprintf("Expected block to be locked on round %d, got %d. Got locked block %X, expected %X", lockRound, cs.LockedRound, cs.LockedBlock.Hash(), lockedBlockHash)) 237 } 238 } 239 } 240 241 func validatePrevoteAndPrecommit(t *testing.T, cs *ConsensusState, thisRound, lockRound int, privVal *validatorStub, votedBlockHash, lockedBlockHash []byte) { 242 t.Helper() 243 244 // verify the prevote 245 validatePrevote(cs, thisRound, privVal, votedBlockHash) 246 // verify precommit 247 validatePrecommit(t, cs, thisRound, lockRound, privVal, votedBlockHash, lockedBlockHash) 248 } 249 250 func subscribeToVoter(cs *ConsensusState, addr crypto.Address) <-chan events.Event { 251 return events.SubscribeFiltered(cs.evsw, testSubscriber, func(event events.Event) bool { 252 if vote, ok := event.(types.EventVote); ok { 253 if vote.Vote.ValidatorAddress == addr { 254 return true 255 } 256 } 257 return false 258 }) 259 } 260 261 // ------------------------------------------------------------------------------- 262 // consensus states 263 264 func newConsensusState(state sm.State, pv types.PrivValidator, app abci.Application) *ConsensusState { 265 config, _ := cfg.ResetTestRoot("consensus_state_test") 266 return newConsensusStateWithConfig(config, state, pv, app) 267 } 268 269 func newConsensusStateWithConfig(thisConfig *cfg.Config, state sm.State, pv types.PrivValidator, app abci.Application) *ConsensusState { 270 blockDB := memdb.NewMemDB() 271 return newConsensusStateWithConfigAndBlockStore(thisConfig, state, pv, app, blockDB) 272 } 273 274 func newConsensusStateWithConfigAndBlockStore(thisConfig *cfg.Config, state sm.State, pv types.PrivValidator, app abci.Application, blockDB dbm.DB) *ConsensusState { 275 // Get BlockStore 276 blockStore := store.NewBlockStore(blockDB) 277 278 // one for mempool, one for consensus 279 mtx := new(sync.Mutex) 280 proxyAppConnMem := abcicli.NewLocalClient(mtx, app) 281 proxyAppConnCon := abcicli.NewLocalClient(mtx, app) 282 283 // Make Mempool 284 mempool := mempl.NewCListMempool(thisConfig.Mempool, proxyAppConnMem, 0, state.ConsensusParams.Block.MaxTxBytes) 285 mempool.SetLogger(log.NewNoopLogger().With("module", "mempool")) 286 if thisConfig.Consensus.WaitForTxs() { 287 mempool.EnableTxsAvailable() 288 } 289 290 // Make ConsensusState 291 stateDB := blockDB 292 sm.SaveState(stateDB, state) // for save height 1's validators info 293 blockExec := sm.NewBlockExecutor(stateDB, log.NewNoopLogger(), proxyAppConnCon, mempool) 294 cs := NewConsensusState(thisConfig.Consensus, state, blockExec, blockStore, mempool) 295 cs.SetLogger(log.NewNoopLogger().With("module", "consensus")) 296 cs.SetPrivValidator(pv) 297 298 evsw := events.NewEventSwitch() 299 evsw.SetLogger(log.NewNoopLogger().With("module", "events")) 300 evsw.Start() 301 cs.SetEventSwitch(evsw) 302 return cs 303 } 304 305 func loadPrivValidator(config *cfg.Config) *privval.FilePV { 306 privValidatorKeyFile := config.PrivValidatorKeyFile() 307 ensureDir(filepath.Dir(privValidatorKeyFile), 0o700) 308 privValidatorStateFile := config.PrivValidatorStateFile() 309 privValidator := privval.LoadOrGenFilePV(privValidatorKeyFile, privValidatorStateFile) 310 privValidator.Reset() 311 return privValidator 312 } 313 314 func randConsensusState(nValidators int) (*ConsensusState, []*validatorStub) { 315 // Get State 316 state, privVals := randGenesisState(nValidators, false, 10) 317 318 vss := make([]*validatorStub, nValidators) 319 320 cs := newConsensusState(state, privVals[0], counter.NewCounterApplication(true)) 321 322 for i := 0; i < nValidators; i++ { 323 vss[i] = NewValidatorStub(privVals[i], i) 324 } 325 // since cs1 starts at 1 326 incrementHeight(vss[1:]...) 327 328 return cs, vss 329 } 330 331 // ------------------------------------------------------------------------------- 332 333 func ensureNoNewEvent(ch <-chan events.Event, timeout time.Duration, 334 errorMessage string, 335 ) { 336 select { 337 case <-time.After(timeout): 338 break 339 case <-ch: 340 panic(errorMessage) 341 } 342 } 343 344 func ensureNoNewEventOnChannel(ch <-chan events.Event) { 345 ensureNoNewEvent( 346 ch, 347 ensureTimeout, 348 "We should be stuck waiting, not receiving new event on the channel") 349 } 350 351 func ensureNoNewRoundStep(stepCh <-chan events.Event) { 352 ensureNoNewEvent( 353 stepCh, 354 ensureTimeout, 355 "We should be stuck waiting, not receiving NewRoundStep event") 356 } 357 358 func ensureNoNewUnlock(unlockCh <-chan events.Event) { 359 ensureNoNewEvent( 360 unlockCh, 361 ensureTimeout, 362 "We should be stuck waiting, not receiving Unlock event") 363 } 364 365 func ensureNoNewTimeout(stepCh <-chan events.Event, timeout int64) { 366 timeoutDuration := time.Duration(timeout*10) * time.Nanosecond 367 ensureNoNewEvent( 368 stepCh, 369 timeoutDuration, 370 "We should be stuck waiting, not receiving NewTimeout event") 371 } 372 373 func ensureNewEvent(ch <-chan events.Event, height int64, round int, timeout time.Duration, errorMessage string) { 374 select { 375 case <-time.After(timeout): 376 osm.PrintAllGoroutines() 377 panic(errorMessage) 378 case msg := <-ch: 379 csevent, ok := msg.(cstypes.ConsensusEvent) 380 if !ok { 381 panic(fmt.Sprintf("expected a ConsensusEvent, got %T. Wrong subscription channel?", 382 msg)) 383 } 384 if csevent.GetHRS().Height != height { 385 panic(fmt.Sprintf("expected height %v, got %v", height, csevent.GetHRS().Height)) 386 } 387 if csevent.GetHRS().Round != round { 388 panic(fmt.Sprintf("expected round %v, got %v", round, csevent.GetHRS().Round)) 389 } 390 // TODO: We could check also for a step at this point! 391 } 392 } 393 394 func ensureNewRound(roundCh <-chan events.Event, height int64, round int) { 395 select { 396 case <-time.After(ensureTimeout): 397 panic("Timeout expired while waiting for NewRound event") 398 case msg := <-roundCh: 399 newRoundEvent, ok := msg.(cstypes.EventNewRound) 400 if !ok { 401 panic(fmt.Sprintf("expected a EventNewRound, got %T. Wrong subscription channel?", 402 msg)) 403 } 404 if newRoundEvent.Height != height { 405 panic(fmt.Sprintf("expected height %v, got %v", height, newRoundEvent.Height)) 406 } 407 if newRoundEvent.Round != round { 408 panic(fmt.Sprintf("expected round %v, got %v", round, newRoundEvent.Round)) 409 } 410 } 411 } 412 413 func ensureNewRoundStep(stepCh <-chan events.Event, height int64, round int, step cstypes.RoundStepType) { 414 select { 415 case <-time.After(ensureTimeout): 416 panic("Timeout expired while waiting for NewRoundStep event") 417 case msg := <-stepCh: 418 newStepEvent, ok := msg.(cstypes.EventNewRoundStep) 419 if !ok { 420 panic(fmt.Sprintf("expected a EventNewRound, got %T. Wrong subscription channel?", 421 msg)) 422 } 423 if newStepEvent.Height != height { 424 panic(fmt.Sprintf("expected height %v, got %v", height, newStepEvent.Height)) 425 } 426 if newStepEvent.Round != round { 427 panic(fmt.Sprintf("expected round %v, got %v", round, newStepEvent.Round)) 428 } 429 if newStepEvent.Step != step { 430 panic(fmt.Sprintf("expected step %v, got %v", step, newStepEvent.Step)) 431 } 432 } 433 } 434 435 func ensureNewTimeout(timeoutCh <-chan events.Event, height int64, round int, timeout int64) { 436 timeoutDuration := (time.Duration(timeout))*time.Nanosecond + ensureTimeout 437 ensureNewEvent(timeoutCh, height, round, timeoutDuration, 438 "Timeout expired while waiting for NewTimeout event") 439 } 440 441 func ensureNewProposal(proposalCh <-chan events.Event, height int64, round int) { 442 select { 443 case <-time.After(ensureTimeout): 444 panic("Timeout expired while waiting for NewProposal event") 445 case msg := <-proposalCh: 446 proposalEvent, ok := msg.(cstypes.EventCompleteProposal) 447 if !ok { 448 panic(fmt.Sprintf("expected a EventCompleteProposal, got %T. Wrong subscription channel?", 449 msg)) 450 } 451 if proposalEvent.Height != height { 452 panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height)) 453 } 454 if proposalEvent.Round != round { 455 panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round)) 456 } 457 } 458 } 459 460 func ensureNewValidBlock(validBlockCh <-chan events.Event, height int64, round int) { 461 ensureNewEvent(validBlockCh, height, round, ensureTimeout, 462 "Timeout expired while waiting for NewValidBlock event") 463 } 464 465 func ensureNewBlock(blockCh <-chan events.Event, height int64) { 466 select { 467 case <-time.After(ensureTimeout): 468 panic("Timeout expired while waiting for NewBlock event") 469 case msg := <-blockCh: 470 blockEvent, ok := msg.(types.EventNewBlock) 471 if !ok { 472 panic(fmt.Sprintf("expected a EventNewBlock, got %T. Wrong subscription channel?", 473 msg)) 474 } 475 if blockEvent.Block.Height != height { 476 panic(fmt.Sprintf("expected height %v, got %v", height, blockEvent.Block.Height)) 477 } 478 } 479 } 480 481 func ensureNewBlockHeader(blockCh <-chan events.Event, height int64, blockHash []byte) { 482 select { 483 case <-time.After(ensureTimeout): 484 panic("Timeout expired while waiting for NewBlockHeader event") 485 case msg := <-blockCh: 486 blockHeaderEvent, ok := msg.(types.EventNewBlockHeader) 487 if !ok { 488 panic(fmt.Sprintf("expected a EventNewBlockHeader, got %T. Wrong subscription channel?", 489 msg)) 490 } 491 if blockHeaderEvent.Header.Height != height { 492 panic(fmt.Sprintf("expected height %v, got %v", height, blockHeaderEvent.Header.Height)) 493 } 494 if !bytes.Equal(blockHeaderEvent.Header.Hash(), blockHash) { 495 panic(fmt.Sprintf("expected header %X, got %X", blockHash, blockHeaderEvent.Header.Hash())) 496 } 497 } 498 } 499 500 func ensureNewUnlock(unlockCh <-chan events.Event, height int64, round int) { 501 ensureNewEvent(unlockCh, height, round, ensureTimeout, 502 "Timeout expired while waiting for NewUnlock event") 503 } 504 505 func ensureProposal(proposalCh <-chan events.Event, height int64, round int, propID types.BlockID) { 506 select { 507 case <-time.After(ensureTimeout): 508 panic("Timeout expired while waiting for NewProposal event") 509 case msg := <-proposalCh: 510 proposalEvent, ok := msg.(cstypes.EventCompleteProposal) 511 if !ok { 512 panic(fmt.Sprintf("expected a EventCompleteProposal, got %T. Wrong subscription channel?", 513 msg)) 514 } 515 if proposalEvent.Height != height { 516 panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height)) 517 } 518 if proposalEvent.Round != round { 519 panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round)) 520 } 521 if !proposalEvent.BlockID.Equals(propID) { 522 panic("Proposed block does not match expected block") 523 } 524 } 525 } 526 527 func ensurePrecommit(voteCh <-chan events.Event, height int64, round int) { 528 ensureVote(voteCh, height, round, types.PrecommitType) 529 } 530 531 func ensurePrevote(voteCh <-chan events.Event, height int64, round int) { 532 ensureVote(voteCh, height, round, types.PrevoteType) 533 } 534 535 func ensureVote(voteCh <-chan events.Event, height int64, round int, 536 voteType types.SignedMsgType, 537 ) { 538 select { 539 case <-time.After(ensureTimeout): 540 panic("Timeout expired while waiting for NewVote event") 541 case msg := <-voteCh: 542 voteEvent, ok := msg.(types.EventVote) 543 if !ok { 544 panic(fmt.Sprintf("expected a EventVote, got %T. Wrong subscription channel?", 545 msg)) 546 } 547 vote := voteEvent.Vote 548 if vote.Height != height { 549 panic(fmt.Sprintf("expected height %v, got %v", height, vote.Height)) 550 } 551 if vote.Round != round { 552 panic(fmt.Sprintf("expected round %v, got %v", round, vote.Round)) 553 } 554 if vote.Type != voteType { 555 panic(fmt.Sprintf("expected type %v, got %v", voteType, vote.Type)) 556 } 557 } 558 } 559 560 func ensureNewEventOnChannel(ch <-chan events.Event) { 561 select { 562 case <-time.After(ensureTimeout): 563 panic("Timeout expired while waiting for new activity on the channel") 564 case <-ch: 565 } 566 } 567 568 // ------------------------------------------------------------------------------- 569 // consensus nets 570 571 func randConsensusNet(nValidators int, testName string, tickerFunc func() TimeoutTicker, 572 appFunc func() abci.Application, configOpts ...func(*cfg.Config), 573 ) ([]*ConsensusState, cleanupFunc) { 574 genDoc, privVals := randGenesisDoc(nValidators, false, 30) 575 css := make([]*ConsensusState, nValidators) 576 apps := make([]abci.Application, nValidators) 577 logger := log.NewNoopLogger() 578 configRootDirs := make([]string, 0, nValidators) 579 for i := 0; i < nValidators; i++ { 580 stateDB := memdb.NewMemDB() // each state needs its own db 581 state, _ := sm.LoadStateFromDBOrGenesisDoc(stateDB, genDoc) 582 thisConfig, _ := ResetConfig(fmt.Sprintf("%s_%d", testName, i)) 583 configRootDirs = append(configRootDirs, thisConfig.RootDir) 584 for _, opt := range configOpts { 585 opt(thisConfig) 586 } 587 ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal 588 app := appFunc() 589 vals := state.Validators.ABCIValidatorUpdates() 590 app.InitChain(abci.RequestInitChain{Validators: vals}) 591 592 css[i] = newConsensusStateWithConfigAndBlockStore(thisConfig, state, privVals[i], app, stateDB) 593 css[i].SetTimeoutTicker(tickerFunc()) 594 css[i].SetLogger(logger.With("validator", i, "module", "consensus")) 595 apps[i] = app 596 } 597 return css, func() { 598 for _, dir := range configRootDirs { 599 os.RemoveAll(dir) 600 } 601 for _, cs := range css { 602 cs.Stop() 603 cs.Wait() 604 } 605 for _, app := range apps { 606 app.Close() 607 } 608 } 609 } 610 611 // nPeers = nValidators + nNotValidator 612 func randConsensusNetWithPeers(nValidators, nPeers int, testName string, tickerFunc func() TimeoutTicker, appFunc func(string) abci.Application) ([]*ConsensusState, *types.GenesisDoc, *cfg.Config, cleanupFunc) { 613 genDoc, privVals := randGenesisDoc(nValidators, false, testMinPower) 614 css := make([]*ConsensusState, nPeers) 615 apps := make([]abci.Application, nPeers) 616 logger := log.NewNoopLogger() 617 var peer0Config *cfg.Config 618 configRootDirs := make([]string, 0, nPeers) 619 for i := 0; i < nPeers; i++ { 620 stateDB := memdb.NewMemDB() // each state needs its own db 621 state, _ := sm.LoadStateFromDBOrGenesisDoc(stateDB, genDoc) 622 thisConfig, _ := ResetConfig(fmt.Sprintf("%s_%d", testName, i)) 623 configRootDirs = append(configRootDirs, thisConfig.RootDir) 624 ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal 625 if i == 0 { 626 peer0Config = thisConfig 627 } 628 var privVal types.PrivValidator 629 if i < nValidators { 630 privVal = privVals[i] 631 } else { 632 tempKeyFile, err := os.CreateTemp("", "priv_validator_key_") 633 if err != nil { 634 panic(err) 635 } 636 tempStateFile, err := os.CreateTemp("", "priv_validator_state_") 637 if err != nil { 638 panic(err) 639 } 640 641 privVal = privval.GenFilePV(tempKeyFile.Name(), tempStateFile.Name()) 642 } 643 644 app := appFunc(path.Join(config.DBDir(), fmt.Sprintf("%s_%d", testName, i))) 645 vals := state.Validators.ABCIValidatorUpdates() 646 if _, ok := app.(*kvstore.PersistentKVStoreApplication); ok { 647 state.AppVersion = kvstore.AppVersion 648 // simulate handshake, receive app version. If don't do this, replay test will fail 649 } 650 app.InitChain(abci.RequestInitChain{Validators: vals}) 651 // sm.SaveState(stateDB,state) //height 1's validatorsInfo already saved in LoadStateFromDBOrGenesisDoc above 652 653 css[i] = newConsensusStateWithConfig(thisConfig, state, privVal, app) 654 css[i].SetTimeoutTicker(tickerFunc()) 655 css[i].SetLogger(logger.With("validator", i, "module", "consensus")) 656 apps[i] = app 657 } 658 return css, genDoc, peer0Config, func() { 659 for _, dir := range configRootDirs { 660 os.RemoveAll(dir) 661 } 662 for _, cs := range css { 663 cs.Stop() 664 cs.Wait() 665 } 666 for _, app := range apps { 667 app.Close() 668 } 669 } 670 } 671 672 // ------------------------------------------------------------------------------- 673 // genesis 674 675 func randGenesisDoc(numValidators int, randPower bool, minPower int64) (*types.GenesisDoc, []types.PrivValidator) { 676 validators := make([]types.GenesisValidator, numValidators) 677 privValidators := make([]types.PrivValidator, numValidators) 678 for i := 0; i < numValidators; i++ { 679 val, privVal := types.RandValidator(randPower, minPower) 680 validators[i] = types.GenesisValidator{ 681 PubKey: val.PubKey, 682 Power: val.VotingPower, 683 } 684 privValidators[i] = privVal 685 } 686 sort.Sort(types.PrivValidatorsByAddress(privValidators)) 687 688 return &types.GenesisDoc{ 689 GenesisTime: tmtime.Now(), 690 ChainID: config.ChainID(), 691 Validators: validators, 692 }, privValidators 693 } 694 695 func randGenesisState(numValidators int, randPower bool, minPower int64) (sm.State, []types.PrivValidator) { 696 genDoc, privValidators := randGenesisDoc(numValidators, randPower, minPower) 697 s0, _ := sm.MakeGenesisState(genDoc) 698 return s0, privValidators 699 } 700 701 // ------------------------------------ 702 // mock ticker 703 704 func newMockTickerFunc(onlyOnce bool) func() TimeoutTicker { 705 return func() TimeoutTicker { 706 return &mockTicker{ 707 c: make(chan timeoutInfo, 10), 708 onlyOnce: onlyOnce, 709 } 710 } 711 } 712 713 // mock ticker only fires on RoundStepNewHeight 714 // and only once if onlyOnce=true 715 type mockTicker struct { 716 c chan timeoutInfo 717 718 mtx sync.Mutex 719 onlyOnce bool 720 fired bool 721 } 722 723 func (m *mockTicker) Start() error { 724 return nil 725 } 726 727 func (m *mockTicker) Stop() error { 728 return nil 729 } 730 731 func (m *mockTicker) ScheduleTimeout(ti timeoutInfo) { 732 m.mtx.Lock() 733 defer m.mtx.Unlock() 734 if m.onlyOnce && m.fired { 735 return 736 } 737 if ti.Step == cstypes.RoundStepNewHeight { 738 m.c <- ti 739 m.fired = true 740 } 741 } 742 743 func (m *mockTicker) Chan() <-chan timeoutInfo { 744 return m.c 745 } 746 747 func (*mockTicker) SetLogger(_ *slog.Logger) {} 748 749 // ------------------------------------ 750 751 func newCounter() abci.Application { 752 return counter.NewCounterApplication(true) 753 } 754 755 func newPersistentKVStore() abci.Application { 756 dir, err := os.MkdirTemp("", "persistent-kvstore") 757 if err != nil { 758 panic(err) 759 } 760 return kvstore.NewPersistentKVStoreApplication(dir) 761 } 762 763 func newPersistentKVStoreWithPath(dbDir string) abci.Application { 764 return kvstore.NewPersistentKVStoreApplication(dbDir) 765 } 766 767 // ------------------------------------ 768 769 func ensureDrainedChannels(t *testing.T, channels ...any) { 770 t.Helper() 771 772 r := recover() 773 if r == nil { 774 return 775 } 776 777 t.Logf("checking for drained channel") 778 leaks := make(map[string]int) 779 for _, ch := range channels { 780 chVal := reflect.ValueOf(ch) 781 if chVal.Kind() != reflect.Chan { 782 panic(chVal.Type().Name() + " not a channel") 783 } 784 785 maxExp := time.After(time.Second * 5) 786 787 // Use a select statement with reflection 788 cases := []reflect.SelectCase{ 789 {Dir: reflect.SelectRecv, Chan: chVal}, 790 {Dir: reflect.SelectRecv, Chan: reflect.ValueOf(maxExp)}, 791 {Dir: reflect.SelectDefault}, 792 } 793 794 for { 795 chosen, recv, recvOK := reflect.Select(cases) 796 if chosen != 0 || !recvOK { 797 break 798 } 799 800 leaks[reflect.TypeOf(recv.Interface()).String()]++ 801 time.Sleep(time.Millisecond * 500) 802 } 803 } 804 805 for leak, count := range leaks { 806 t.Logf("channel %q: %d events left\n", leak, count) 807 } 808 809 panic(r) 810 }