// Source: github.com/badrootd/nibiru-cometbft@v0.37.5-0.20240307173500-2a75559eee9b/consensus/common_test.go

package consensus

import (
	"bytes"
	"context"
	"fmt"
	"os"
	"path"
	"path/filepath"
	"sort"
	"sync"
	"testing"
	"time"

	"github.com/go-kit/log/term"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"

	dbm "github.com/badrootd/nibiru-db"

	abcicli "github.com/badrootd/nibiru-cometbft/abci/client"
	"github.com/badrootd/nibiru-cometbft/abci/example/kvstore"
	abci "github.com/badrootd/nibiru-cometbft/abci/types"
	cfg "github.com/badrootd/nibiru-cometbft/config"
	cstypes "github.com/badrootd/nibiru-cometbft/consensus/types"
	cmtbytes "github.com/badrootd/nibiru-cometbft/libs/bytes"
	"github.com/badrootd/nibiru-cometbft/libs/log"
	cmtos "github.com/badrootd/nibiru-cometbft/libs/os"
	cmtpubsub "github.com/badrootd/nibiru-cometbft/libs/pubsub"
	cmtsync "github.com/badrootd/nibiru-cometbft/libs/sync"
	mempl "github.com/badrootd/nibiru-cometbft/mempool"
	mempoolv0 "github.com/badrootd/nibiru-cometbft/mempool/v0"
	mempoolv1 "github.com/badrootd/nibiru-cometbft/mempool/v1" //nolint:staticcheck // SA1019 Priority mempool deprecated but still supported in this release.
	"github.com/badrootd/nibiru-cometbft/p2p"
	"github.com/badrootd/nibiru-cometbft/privval"
	cmtproto "github.com/badrootd/nibiru-cometbft/proto/tendermint/types"
	sm "github.com/badrootd/nibiru-cometbft/state"
	"github.com/badrootd/nibiru-cometbft/store"
	"github.com/badrootd/nibiru-cometbft/types"
	cmttime "github.com/badrootd/nibiru-cometbft/types/time"
)

// testSubscriber is the subscriber ID passed to the event bus when the helpers
// in this file (e.g. subscribeToVoter) subscribe to consensus events.
const (
	testSubscriber = "test-client"
)

// A cleanupFunc cleans up any config / test files created for a particular
// test.
49 type cleanupFunc func() 50 51 // genesis, chain_id, priv_val 52 var ( 53 config *cfg.Config // NOTE: must be reset for each _test.go file 54 consensusReplayConfig *cfg.Config 55 ensureTimeout = time.Millisecond * 200 56 ) 57 58 func ensureDir(dir string, mode os.FileMode) { 59 if err := cmtos.EnsureDir(dir, mode); err != nil { 60 panic(err) 61 } 62 } 63 64 func ResetConfig(name string) *cfg.Config { 65 return cfg.ResetTestRoot(name) 66 } 67 68 //------------------------------------------------------------------------------- 69 // validator stub (a kvstore consensus peer we control) 70 71 type validatorStub struct { 72 Index int32 // Validator index. NOTE: we don't assume validator set changes. 73 Height int64 74 Round int32 75 types.PrivValidator 76 VotingPower int64 77 lastVote *types.Vote 78 } 79 80 var testMinPower int64 = 10 81 82 func newValidatorStub(privValidator types.PrivValidator, valIndex int32) *validatorStub { 83 return &validatorStub{ 84 Index: valIndex, 85 PrivValidator: privValidator, 86 VotingPower: testMinPower, 87 } 88 } 89 90 func (vs *validatorStub) signVote( 91 voteType cmtproto.SignedMsgType, 92 hash []byte, 93 header types.PartSetHeader, 94 ) (*types.Vote, error) { 95 pubKey, err := vs.PrivValidator.GetPubKey() 96 if err != nil { 97 return nil, fmt.Errorf("can't get pubkey: %w", err) 98 } 99 100 vote := &types.Vote{ 101 ValidatorIndex: vs.Index, 102 ValidatorAddress: pubKey.Address(), 103 Height: vs.Height, 104 Round: vs.Round, 105 Timestamp: cmttime.Now(), 106 Type: voteType, 107 BlockID: types.BlockID{Hash: hash, PartSetHeader: header}, 108 } 109 v := vote.ToProto() 110 if err := vs.PrivValidator.SignVote(config.ChainID(), v); err != nil { 111 return nil, fmt.Errorf("sign vote failed: %w", err) 112 } 113 114 // ref: signVote in FilePV, the vote should use the privious vote info when the sign data is the same. 
115 if signDataIsEqual(vs.lastVote, v) { 116 v.Signature = vs.lastVote.Signature 117 v.Timestamp = vs.lastVote.Timestamp 118 } 119 120 vote.Signature = v.Signature 121 vote.Timestamp = v.Timestamp 122 123 return vote, err 124 } 125 126 // Sign vote for type/hash/header 127 func signVote(vs *validatorStub, voteType cmtproto.SignedMsgType, hash []byte, header types.PartSetHeader) *types.Vote { 128 v, err := vs.signVote(voteType, hash, header) 129 if err != nil { 130 panic(fmt.Errorf("failed to sign vote: %v", err)) 131 } 132 133 vs.lastVote = v 134 135 return v 136 } 137 138 func signVotes( 139 voteType cmtproto.SignedMsgType, 140 hash []byte, 141 header types.PartSetHeader, 142 vss ...*validatorStub, 143 ) []*types.Vote { 144 votes := make([]*types.Vote, len(vss)) 145 for i, vs := range vss { 146 votes[i] = signVote(vs, voteType, hash, header) 147 } 148 return votes 149 } 150 151 func incrementHeight(vss ...*validatorStub) { 152 for _, vs := range vss { 153 vs.Height++ 154 } 155 } 156 157 func incrementRound(vss ...*validatorStub) { 158 for _, vs := range vss { 159 vs.Round++ 160 } 161 } 162 163 type ValidatorStubsByPower []*validatorStub 164 165 func (vss ValidatorStubsByPower) Len() int { 166 return len(vss) 167 } 168 169 func (vss ValidatorStubsByPower) Less(i, j int) bool { 170 vssi, err := vss[i].GetPubKey() 171 if err != nil { 172 panic(err) 173 } 174 vssj, err := vss[j].GetPubKey() 175 if err != nil { 176 panic(err) 177 } 178 179 if vss[i].VotingPower == vss[j].VotingPower { 180 return bytes.Compare(vssi.Address(), vssj.Address()) == -1 181 } 182 return vss[i].VotingPower > vss[j].VotingPower 183 } 184 185 func (vss ValidatorStubsByPower) Swap(i, j int) { 186 it := vss[i] 187 vss[i] = vss[j] 188 vss[i].Index = int32(i) 189 vss[j] = it 190 vss[j].Index = int32(j) 191 } 192 193 //------------------------------------------------------------------------------- 194 // Functions for transitioning the consensus state 195 196 func startTestRound(cs *State, height 
int64, round int32) { 197 cs.enterNewRound(height, round) 198 cs.startRoutines(0) 199 } 200 201 // Create proposal block from cs1 but sign it with vs. 202 func decideProposal( 203 t *testing.T, 204 cs1 *State, 205 vs *validatorStub, 206 height int64, 207 round int32, 208 ) (proposal *types.Proposal, block *types.Block) { 209 cs1.mtx.Lock() 210 block, err := cs1.createProposalBlock() 211 require.NoError(t, err) 212 blockParts, err := block.MakePartSet(types.BlockPartSizeBytes) 213 require.NoError(t, err) 214 validRound := cs1.ValidRound 215 chainID := cs1.state.ChainID 216 cs1.mtx.Unlock() 217 if block == nil { 218 panic("Failed to createProposalBlock. Did you forget to add commit for previous block?") 219 } 220 221 // Make proposal 222 polRound, propBlockID := validRound, types.BlockID{Hash: block.Hash(), PartSetHeader: blockParts.Header()} 223 proposal = types.NewProposal(height, round, polRound, propBlockID) 224 p := proposal.ToProto() 225 if err := vs.SignProposal(chainID, p); err != nil { 226 panic(err) 227 } 228 229 proposal.Signature = p.Signature 230 231 return 232 } 233 234 func addVotes(to *State, votes ...*types.Vote) { 235 for _, vote := range votes { 236 to.peerMsgQueue <- msgInfo{Msg: &VoteMessage{vote}} 237 } 238 } 239 240 func signAddVotes( 241 to *State, 242 voteType cmtproto.SignedMsgType, 243 hash []byte, 244 header types.PartSetHeader, 245 vss ...*validatorStub, 246 ) { 247 votes := signVotes(voteType, hash, header, vss...) 248 addVotes(to, votes...) 
249 } 250 251 func validatePrevote(t *testing.T, cs *State, round int32, privVal *validatorStub, blockHash []byte) { 252 prevotes := cs.Votes.Prevotes(round) 253 pubKey, err := privVal.GetPubKey() 254 require.NoError(t, err) 255 address := pubKey.Address() 256 var vote *types.Vote 257 if vote = prevotes.GetByAddress(address); vote == nil { 258 panic("Failed to find prevote from validator") 259 } 260 if blockHash == nil { 261 if vote.BlockID.Hash != nil { 262 panic(fmt.Sprintf("Expected prevote to be for nil, got %X", vote.BlockID.Hash)) 263 } 264 } else { 265 if !bytes.Equal(vote.BlockID.Hash, blockHash) { 266 panic(fmt.Sprintf("Expected prevote to be for %X, got %X", blockHash, vote.BlockID.Hash)) 267 } 268 } 269 } 270 271 func validateLastPrecommit(t *testing.T, cs *State, privVal *validatorStub, blockHash []byte) { 272 votes := cs.LastCommit 273 pv, err := privVal.GetPubKey() 274 require.NoError(t, err) 275 address := pv.Address() 276 var vote *types.Vote 277 if vote = votes.GetByAddress(address); vote == nil { 278 panic("Failed to find precommit from validator") 279 } 280 if !bytes.Equal(vote.BlockID.Hash, blockHash) { 281 panic(fmt.Sprintf("Expected precommit to be for %X, got %X", blockHash, vote.BlockID.Hash)) 282 } 283 } 284 285 func validatePrecommit( 286 t *testing.T, 287 cs *State, 288 thisRound, 289 lockRound int32, 290 privVal *validatorStub, 291 votedBlockHash, 292 lockedBlockHash []byte, 293 ) { 294 precommits := cs.Votes.Precommits(thisRound) 295 pv, err := privVal.GetPubKey() 296 require.NoError(t, err) 297 address := pv.Address() 298 var vote *types.Vote 299 if vote = precommits.GetByAddress(address); vote == nil { 300 panic("Failed to find precommit from validator") 301 } 302 303 if votedBlockHash == nil { 304 if vote.BlockID.Hash != nil { 305 panic("Expected precommit to be for nil") 306 } 307 } else { 308 if !bytes.Equal(vote.BlockID.Hash, votedBlockHash) { 309 panic("Expected precommit to be for proposal block") 310 } 311 } 312 313 if 
lockedBlockHash == nil { 314 if cs.LockedRound != lockRound || cs.LockedBlock != nil { 315 panic(fmt.Sprintf( 316 "Expected to be locked on nil at round %d. Got locked at round %d with block %v", 317 lockRound, 318 cs.LockedRound, 319 cs.LockedBlock)) 320 } 321 } else { 322 if cs.LockedRound != lockRound || !bytes.Equal(cs.LockedBlock.Hash(), lockedBlockHash) { 323 panic(fmt.Sprintf( 324 "Expected block to be locked on round %d, got %d. Got locked block %X, expected %X", 325 lockRound, 326 cs.LockedRound, 327 cs.LockedBlock.Hash(), 328 lockedBlockHash)) 329 } 330 } 331 } 332 333 func validatePrevoteAndPrecommit( 334 t *testing.T, 335 cs *State, 336 thisRound, 337 lockRound int32, 338 privVal *validatorStub, 339 votedBlockHash, 340 lockedBlockHash []byte, 341 ) { 342 // verify the prevote 343 validatePrevote(t, cs, thisRound, privVal, votedBlockHash) 344 // verify precommit 345 cs.mtx.Lock() 346 validatePrecommit(t, cs, thisRound, lockRound, privVal, votedBlockHash, lockedBlockHash) 347 cs.mtx.Unlock() 348 } 349 350 func subscribeToVoter(cs *State, addr []byte) <-chan cmtpubsub.Message { 351 votesSub, err := cs.eventBus.SubscribeUnbuffered(context.Background(), testSubscriber, types.EventQueryVote) 352 if err != nil { 353 panic(fmt.Sprintf("failed to subscribe %s to %v", testSubscriber, types.EventQueryVote)) 354 } 355 ch := make(chan cmtpubsub.Message) 356 go func() { 357 for msg := range votesSub.Out() { 358 vote := msg.Data().(types.EventDataVote) 359 // we only fire for our own votes 360 if bytes.Equal(addr, vote.Vote.ValidatorAddress) { 361 ch <- msg 362 } 363 } 364 }() 365 return ch 366 } 367 368 //------------------------------------------------------------------------------- 369 // consensus states 370 371 func newState(state sm.State, pv types.PrivValidator, app abci.Application) *State { 372 config := cfg.ResetTestRoot("consensus_state_test") 373 return newStateWithConfig(config, state, pv, app) 374 } 375 376 func newStateWithConfig( 377 thisConfig 
*cfg.Config, 378 state sm.State, 379 pv types.PrivValidator, 380 app abci.Application, 381 ) *State { 382 blockDB := dbm.NewMemDB() 383 return newStateWithConfigAndBlockStore(thisConfig, state, pv, app, blockDB) 384 } 385 386 func newStateWithConfigAndBlockStore( 387 thisConfig *cfg.Config, 388 state sm.State, 389 pv types.PrivValidator, 390 app abci.Application, 391 blockDB dbm.DB, 392 ) *State { 393 // Get BlockStore 394 blockStore := store.NewBlockStore(blockDB) 395 396 // one for mempool, one for consensus 397 mtx := new(cmtsync.Mutex) 398 399 proxyAppConnCon := abcicli.NewLocalClient(mtx, app) 400 proxyAppConnConMem := abcicli.NewLocalClient(mtx, app) 401 // Make Mempool 402 memplMetrics := mempl.NopMetrics() 403 404 // Make Mempool 405 var mempool mempl.Mempool 406 407 switch config.Mempool.Version { 408 case cfg.MempoolV0: 409 mempool = mempoolv0.NewCListMempool(config.Mempool, 410 proxyAppConnConMem, 411 state.LastBlockHeight, 412 mempoolv0.WithMetrics(memplMetrics), 413 mempoolv0.WithPreCheck(sm.TxPreCheck(state)), 414 mempoolv0.WithPostCheck(sm.TxPostCheck(state))) 415 case cfg.MempoolV1: 416 logger := consensusLogger() 417 mempool = mempoolv1.NewTxMempool(logger, 418 config.Mempool, 419 proxyAppConnConMem, 420 state.LastBlockHeight, 421 mempoolv1.WithMetrics(memplMetrics), 422 mempoolv1.WithPreCheck(sm.TxPreCheck(state)), 423 mempoolv1.WithPostCheck(sm.TxPostCheck(state)), 424 ) 425 } 426 if thisConfig.Consensus.WaitForTxs() { 427 mempool.EnableTxsAvailable() 428 } 429 430 evpool := sm.EmptyEvidencePool{} 431 432 // Make State 433 stateDB := blockDB 434 stateStore := sm.NewStore(stateDB, sm.StoreOptions{ 435 DiscardABCIResponses: false, 436 }) 437 if err := stateStore.Save(state); err != nil { // for save height 1's validators info 438 panic(err) 439 } 440 441 blockExec := sm.NewBlockExecutor(stateStore, log.TestingLogger(), proxyAppConnCon, mempool, evpool) 442 cs := NewState(thisConfig.Consensus, state, blockExec, blockStore, mempool, evpool) 443 
cs.SetLogger(log.TestingLogger().With("module", "consensus")) 444 cs.SetPrivValidator(pv) 445 446 eventBus := types.NewEventBus() 447 eventBus.SetLogger(log.TestingLogger().With("module", "events")) 448 err := eventBus.Start() 449 if err != nil { 450 panic(err) 451 } 452 cs.SetEventBus(eventBus) 453 return cs 454 } 455 456 func loadPrivValidator(config *cfg.Config) *privval.FilePV { 457 privValidatorKeyFile := config.PrivValidatorKeyFile() 458 ensureDir(filepath.Dir(privValidatorKeyFile), 0o700) 459 privValidatorStateFile := config.PrivValidatorStateFile() 460 privValidator := privval.LoadOrGenFilePV(privValidatorKeyFile, privValidatorStateFile) 461 privValidator.Reset() 462 return privValidator 463 } 464 465 func randState(nValidators int) (*State, []*validatorStub) { 466 return randStateWithApp(nValidators, kvstore.NewApplication()) 467 } 468 469 func randStateWithApp(nValidators int, app abci.Application) (*State, []*validatorStub) { 470 // Get State 471 state, privVals := randGenesisState(nValidators, false, 10) 472 473 vss := make([]*validatorStub, nValidators) 474 475 cs := newState(state, privVals[0], app) 476 477 for i := 0; i < nValidators; i++ { 478 vss[i] = newValidatorStub(privVals[i], int32(i)) 479 } 480 // since cs1 starts at 1 481 incrementHeight(vss[1:]...) 
482 483 return cs, vss 484 } 485 486 //------------------------------------------------------------------------------- 487 488 func ensureNoNewEvent(ch <-chan cmtpubsub.Message, timeout time.Duration, 489 errorMessage string, 490 ) { 491 select { 492 case <-time.After(timeout): 493 break 494 case <-ch: 495 panic(errorMessage) 496 } 497 } 498 499 func ensureNoNewEventOnChannel(ch <-chan cmtpubsub.Message) { 500 ensureNoNewEvent( 501 ch, 502 ensureTimeout, 503 "We should be stuck waiting, not receiving new event on the channel") 504 } 505 506 func ensureNoNewRoundStep(stepCh <-chan cmtpubsub.Message) { 507 ensureNoNewEvent( 508 stepCh, 509 ensureTimeout, 510 "We should be stuck waiting, not receiving NewRoundStep event") 511 } 512 513 func ensureNoNewUnlock(unlockCh <-chan cmtpubsub.Message) { 514 ensureNoNewEvent( 515 unlockCh, 516 ensureTimeout, 517 "We should be stuck waiting, not receiving Unlock event") 518 } 519 520 func ensureNoNewTimeout(stepCh <-chan cmtpubsub.Message, timeout int64) { 521 timeoutDuration := time.Duration(timeout*10) * time.Nanosecond 522 ensureNoNewEvent( 523 stepCh, 524 timeoutDuration, 525 "We should be stuck waiting, not receiving NewTimeout event") 526 } 527 528 func ensureNewEvent(ch <-chan cmtpubsub.Message, height int64, round int32, timeout time.Duration, errorMessage string) { 529 select { 530 case <-time.After(timeout): 531 panic(errorMessage) 532 case msg := <-ch: 533 roundStateEvent, ok := msg.Data().(types.EventDataRoundState) 534 if !ok { 535 panic(fmt.Sprintf("expected a EventDataRoundState, got %T. Wrong subscription channel?", 536 msg.Data())) 537 } 538 if roundStateEvent.Height != height { 539 panic(fmt.Sprintf("expected height %v, got %v", height, roundStateEvent.Height)) 540 } 541 if roundStateEvent.Round != round { 542 panic(fmt.Sprintf("expected round %v, got %v", round, roundStateEvent.Round)) 543 } 544 // TODO: We could check also for a step at this point! 
545 } 546 } 547 548 func ensureNewRound(roundCh <-chan cmtpubsub.Message, height int64, round int32) { 549 select { 550 case <-time.After(ensureTimeout): 551 panic("Timeout expired while waiting for NewRound event") 552 case msg := <-roundCh: 553 newRoundEvent, ok := msg.Data().(types.EventDataNewRound) 554 if !ok { 555 panic(fmt.Sprintf("expected a EventDataNewRound, got %T. Wrong subscription channel?", 556 msg.Data())) 557 } 558 if newRoundEvent.Height != height { 559 panic(fmt.Sprintf("expected height %v, got %v", height, newRoundEvent.Height)) 560 } 561 if newRoundEvent.Round != round { 562 panic(fmt.Sprintf("expected round %v, got %v", round, newRoundEvent.Round)) 563 } 564 } 565 } 566 567 func ensureNewTimeout(timeoutCh <-chan cmtpubsub.Message, height int64, round int32, timeout int64) { 568 timeoutDuration := time.Duration(timeout*10) * time.Nanosecond 569 ensureNewEvent(timeoutCh, height, round, timeoutDuration, 570 "Timeout expired while waiting for NewTimeout event") 571 } 572 573 func ensureNewProposal(proposalCh <-chan cmtpubsub.Message, height int64, round int32) { 574 select { 575 case <-time.After(ensureTimeout): 576 panic("Timeout expired while waiting for NewProposal event") 577 case msg := <-proposalCh: 578 proposalEvent, ok := msg.Data().(types.EventDataCompleteProposal) 579 if !ok { 580 panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. 
Wrong subscription channel?", 581 msg.Data())) 582 } 583 if proposalEvent.Height != height { 584 panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height)) 585 } 586 if proposalEvent.Round != round { 587 panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round)) 588 } 589 } 590 } 591 592 func ensureNewValidBlock(validBlockCh <-chan cmtpubsub.Message, height int64, round int32) { 593 ensureNewEvent(validBlockCh, height, round, ensureTimeout, 594 "Timeout expired while waiting for NewValidBlock event") 595 } 596 597 func ensureNewBlock(blockCh <-chan cmtpubsub.Message, height int64) { 598 select { 599 case <-time.After(ensureTimeout): 600 panic("Timeout expired while waiting for NewBlock event") 601 case msg := <-blockCh: 602 blockEvent, ok := msg.Data().(types.EventDataNewBlock) 603 if !ok { 604 panic(fmt.Sprintf("expected a EventDataNewBlock, got %T. Wrong subscription channel?", 605 msg.Data())) 606 } 607 if blockEvent.Block.Height != height { 608 panic(fmt.Sprintf("expected height %v, got %v", height, blockEvent.Block.Height)) 609 } 610 } 611 } 612 613 func ensureNewBlockHeader(blockCh <-chan cmtpubsub.Message, height int64, blockHash cmtbytes.HexBytes) { 614 select { 615 case <-time.After(ensureTimeout): 616 panic("Timeout expired while waiting for NewBlockHeader event") 617 case msg := <-blockCh: 618 blockHeaderEvent, ok := msg.Data().(types.EventDataNewBlockHeader) 619 if !ok { 620 panic(fmt.Sprintf("expected a EventDataNewBlockHeader, got %T. 
Wrong subscription channel?", 621 msg.Data())) 622 } 623 if blockHeaderEvent.Header.Height != height { 624 panic(fmt.Sprintf("expected height %v, got %v", height, blockHeaderEvent.Header.Height)) 625 } 626 if !bytes.Equal(blockHeaderEvent.Header.Hash(), blockHash) { 627 panic(fmt.Sprintf("expected header %X, got %X", blockHash, blockHeaderEvent.Header.Hash())) 628 } 629 } 630 } 631 632 func ensureNewUnlock(unlockCh <-chan cmtpubsub.Message, height int64, round int32) { 633 ensureNewEvent(unlockCh, height, round, ensureTimeout, 634 "Timeout expired while waiting for NewUnlock event") 635 } 636 637 func ensureProposal(proposalCh <-chan cmtpubsub.Message, height int64, round int32, propID types.BlockID) { 638 select { 639 case <-time.After(ensureTimeout): 640 panic("Timeout expired while waiting for NewProposal event") 641 case msg := <-proposalCh: 642 proposalEvent, ok := msg.Data().(types.EventDataCompleteProposal) 643 if !ok { 644 panic(fmt.Sprintf("expected a EventDataCompleteProposal, got %T. 
Wrong subscription channel?", 645 msg.Data())) 646 } 647 if proposalEvent.Height != height { 648 panic(fmt.Sprintf("expected height %v, got %v", height, proposalEvent.Height)) 649 } 650 if proposalEvent.Round != round { 651 panic(fmt.Sprintf("expected round %v, got %v", round, proposalEvent.Round)) 652 } 653 if !proposalEvent.BlockID.Equals(propID) { 654 panic(fmt.Sprintf("Proposed block does not match expected block (%v != %v)", proposalEvent.BlockID, propID)) 655 } 656 } 657 } 658 659 func ensurePrecommit(voteCh <-chan cmtpubsub.Message, height int64, round int32) { 660 ensureVote(voteCh, height, round, cmtproto.PrecommitType) 661 } 662 663 func ensurePrevote(voteCh <-chan cmtpubsub.Message, height int64, round int32) { 664 ensureVote(voteCh, height, round, cmtproto.PrevoteType) 665 } 666 667 func ensureVote(voteCh <-chan cmtpubsub.Message, height int64, round int32, 668 voteType cmtproto.SignedMsgType, 669 ) { 670 select { 671 case <-time.After(ensureTimeout): 672 panic("Timeout expired while waiting for NewVote event") 673 case msg := <-voteCh: 674 voteEvent, ok := msg.Data().(types.EventDataVote) 675 if !ok { 676 panic(fmt.Sprintf("expected a EventDataVote, got %T. 
Wrong subscription channel?", 677 msg.Data())) 678 } 679 vote := voteEvent.Vote 680 if vote.Height != height { 681 panic(fmt.Sprintf("expected height %v, got %v", height, vote.Height)) 682 } 683 if vote.Round != round { 684 panic(fmt.Sprintf("expected round %v, got %v", round, vote.Round)) 685 } 686 if vote.Type != voteType { 687 panic(fmt.Sprintf("expected type %v, got %v", voteType, vote.Type)) 688 } 689 } 690 } 691 692 func ensurePrevoteMatch(t *testing.T, voteCh <-chan cmtpubsub.Message, height int64, round int32, hash []byte) { 693 t.Helper() 694 ensureVoteMatch(t, voteCh, height, round, hash, cmtproto.PrevoteType) 695 } 696 697 func ensureVoteMatch(t *testing.T, voteCh <-chan cmtpubsub.Message, height int64, round int32, hash []byte, voteType cmtproto.SignedMsgType) { 698 t.Helper() 699 select { 700 case <-time.After(ensureTimeout): 701 t.Fatal("Timeout expired while waiting for NewVote event") 702 case msg := <-voteCh: 703 voteEvent, ok := msg.Data().(types.EventDataVote) 704 require.True(t, ok, "expected a EventDataVote, got %T. 
Wrong subscription channel?", 705 msg.Data()) 706 707 vote := voteEvent.Vote 708 assert.Equal(t, height, vote.Height, "expected height %d, but got %d", height, vote.Height) 709 assert.Equal(t, round, vote.Round, "expected round %d, but got %d", round, vote.Round) 710 assert.Equal(t, voteType, vote.Type, "expected type %s, but got %s", voteType, vote.Type) 711 if hash == nil { 712 require.Nil(t, vote.BlockID.Hash, "Expected prevote to be for nil, got %X", vote.BlockID.Hash) 713 } else { 714 require.True(t, bytes.Equal(vote.BlockID.Hash, hash), "Expected prevote to be for %X, got %X", hash, vote.BlockID.Hash) 715 } 716 } 717 } 718 719 func ensurePrecommitTimeout(ch <-chan cmtpubsub.Message) { 720 select { 721 case <-time.After(ensureTimeout): 722 panic("Timeout expired while waiting for the Precommit to Timeout") 723 case <-ch: 724 } 725 } 726 727 func ensureNewEventOnChannel(ch <-chan cmtpubsub.Message) { 728 select { 729 case <-time.After(ensureTimeout): 730 panic("Timeout expired while waiting for new activity on the channel") 731 case <-ch: 732 } 733 } 734 735 //------------------------------------------------------------------------------- 736 // consensus nets 737 738 // consensusLogger is a TestingLogger which uses a different 739 // color for each validator ("validator" key must exist). 
740 func consensusLogger() log.Logger { 741 return log.TestingLoggerWithColorFn(func(keyvals ...interface{}) term.FgBgColor { 742 for i := 0; i < len(keyvals)-1; i += 2 { 743 if keyvals[i] == "validator" { 744 return term.FgBgColor{Fg: term.Color(uint8(keyvals[i+1].(int) + 1))} 745 } 746 } 747 return term.FgBgColor{} 748 }).With("module", "consensus") 749 } 750 751 func randConsensusNet(nValidators int, testName string, tickerFunc func() TimeoutTicker, 752 appFunc func() abci.Application, configOpts ...func(*cfg.Config), 753 ) ([]*State, cleanupFunc) { 754 genDoc, privVals := randGenesisDoc(nValidators, false, 30) 755 css := make([]*State, nValidators) 756 logger := consensusLogger() 757 configRootDirs := make([]string, 0, nValidators) 758 for i := 0; i < nValidators; i++ { 759 stateDB := dbm.NewMemDB() // each state needs its own db 760 stateStore := sm.NewStore(stateDB, sm.StoreOptions{ 761 DiscardABCIResponses: false, 762 }) 763 state, _ := stateStore.LoadFromDBOrGenesisDoc(genDoc) 764 thisConfig := ResetConfig(fmt.Sprintf("%s_%d", testName, i)) 765 configRootDirs = append(configRootDirs, thisConfig.RootDir) 766 for _, opt := range configOpts { 767 opt(thisConfig) 768 } 769 ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal 770 app := appFunc() 771 vals := types.TM2PB.ValidatorUpdates(state.Validators) 772 app.InitChain(abci.RequestInitChain{Validators: vals}) 773 774 css[i] = newStateWithConfigAndBlockStore(thisConfig, state, privVals[i], app, stateDB) 775 css[i].SetTimeoutTicker(tickerFunc()) 776 css[i].SetLogger(logger.With("validator", i, "module", "consensus")) 777 } 778 return css, func() { 779 for _, dir := range configRootDirs { 780 os.RemoveAll(dir) 781 } 782 } 783 } 784 785 // nPeers = nValidators + nNotValidator 786 func randConsensusNetWithPeers( 787 nValidators, 788 nPeers int, 789 testName string, 790 tickerFunc func() TimeoutTicker, 791 appFunc func(string) abci.Application, 792 ) ([]*State, *types.GenesisDoc, 
*cfg.Config, cleanupFunc) { 793 genDoc, privVals := randGenesisDoc(nValidators, false, testMinPower) 794 css := make([]*State, nPeers) 795 logger := consensusLogger() 796 var peer0Config *cfg.Config 797 configRootDirs := make([]string, 0, nPeers) 798 for i := 0; i < nPeers; i++ { 799 stateDB := dbm.NewMemDB() // each state needs its own db 800 stateStore := sm.NewStore(stateDB, sm.StoreOptions{ 801 DiscardABCIResponses: false, 802 }) 803 state, _ := stateStore.LoadFromDBOrGenesisDoc(genDoc) 804 thisConfig := ResetConfig(fmt.Sprintf("%s_%d", testName, i)) 805 configRootDirs = append(configRootDirs, thisConfig.RootDir) 806 ensureDir(filepath.Dir(thisConfig.Consensus.WalFile()), 0o700) // dir for wal 807 if i == 0 { 808 peer0Config = thisConfig 809 } 810 var privVal types.PrivValidator 811 if i < nValidators { 812 privVal = privVals[i] 813 } else { 814 tempKeyFile, err := os.CreateTemp("", "priv_validator_key_") 815 if err != nil { 816 panic(err) 817 } 818 tempStateFile, err := os.CreateTemp("", "priv_validator_state_") 819 if err != nil { 820 panic(err) 821 } 822 823 privVal = privval.GenFilePV(tempKeyFile.Name(), tempStateFile.Name()) 824 } 825 826 app := appFunc(path.Join(config.DBDir(), fmt.Sprintf("%s_%d", testName, i))) 827 vals := types.TM2PB.ValidatorUpdates(state.Validators) 828 if _, ok := app.(*kvstore.PersistentKVStoreApplication); ok { 829 // simulate handshake, receive app version. 
If don't do this, replay test will fail 830 state.Version.Consensus.App = kvstore.ProtocolVersion 831 } 832 app.InitChain(abci.RequestInitChain{Validators: vals}) 833 // sm.SaveState(stateDB,state) //height 1's validatorsInfo already saved in LoadStateFromDBOrGenesisDoc above 834 835 css[i] = newStateWithConfig(thisConfig, state, privVal, app) 836 css[i].SetTimeoutTicker(tickerFunc()) 837 css[i].SetLogger(logger.With("validator", i, "module", "consensus")) 838 } 839 return css, genDoc, peer0Config, func() { 840 for _, dir := range configRootDirs { 841 os.RemoveAll(dir) 842 } 843 } 844 } 845 846 func getSwitchIndex(switches []*p2p.Switch, peer p2p.Peer) int { 847 for i, s := range switches { 848 if peer.NodeInfo().ID() == s.NodeInfo().ID() { 849 return i 850 } 851 } 852 panic("didnt find peer in switches") 853 } 854 855 //------------------------------------------------------------------------------- 856 // genesis 857 858 func randGenesisDoc(numValidators int, randPower bool, minPower int64) (*types.GenesisDoc, []types.PrivValidator) { 859 validators := make([]types.GenesisValidator, numValidators) 860 privValidators := make([]types.PrivValidator, numValidators) 861 for i := 0; i < numValidators; i++ { 862 val, privVal := types.RandValidator(randPower, minPower) 863 validators[i] = types.GenesisValidator{ 864 PubKey: val.PubKey, 865 Power: val.VotingPower, 866 } 867 privValidators[i] = privVal 868 } 869 sort.Sort(types.PrivValidatorsByAddress(privValidators)) 870 871 return &types.GenesisDoc{ 872 GenesisTime: cmttime.Now(), 873 InitialHeight: 1, 874 ChainID: config.ChainID(), 875 Validators: validators, 876 }, privValidators 877 } 878 879 func randGenesisState(numValidators int, randPower bool, minPower int64) (sm.State, []types.PrivValidator) { 880 genDoc, privValidators := randGenesisDoc(numValidators, randPower, minPower) 881 s0, _ := sm.MakeGenesisState(genDoc) 882 return s0, privValidators 883 } 884 885 //------------------------------------ 886 // mock ticker 
887 888 func newMockTickerFunc(onlyOnce bool) func() TimeoutTicker { 889 return func() TimeoutTicker { 890 return &mockTicker{ 891 c: make(chan timeoutInfo, 10), 892 onlyOnce: onlyOnce, 893 } 894 } 895 } 896 897 // mock ticker only fires on RoundStepNewHeight 898 // and only once if onlyOnce=true 899 type mockTicker struct { 900 c chan timeoutInfo 901 902 mtx sync.Mutex 903 onlyOnce bool 904 fired bool 905 } 906 907 func (m *mockTicker) Start() error { 908 return nil 909 } 910 911 func (m *mockTicker) Stop() error { 912 return nil 913 } 914 915 func (m *mockTicker) ScheduleTimeout(ti timeoutInfo) { 916 m.mtx.Lock() 917 defer m.mtx.Unlock() 918 if m.onlyOnce && m.fired { 919 return 920 } 921 if ti.Step == cstypes.RoundStepNewHeight { 922 m.c <- ti 923 m.fired = true 924 } 925 } 926 927 func (m *mockTicker) Chan() <-chan timeoutInfo { 928 return m.c 929 } 930 931 func (*mockTicker) SetLogger(log.Logger) {} 932 933 func newPersistentKVStore() abci.Application { 934 dir, err := os.MkdirTemp("", "persistent-kvstore") 935 if err != nil { 936 panic(err) 937 } 938 return kvstore.NewPersistentKVStoreApplication(dir) 939 } 940 941 func newKVStore() abci.Application { 942 return kvstore.NewApplication() 943 } 944 945 func newPersistentKVStoreWithPath(dbDir string) abci.Application { 946 return kvstore.NewPersistentKVStoreApplication(dbDir) 947 } 948 949 func signDataIsEqual(v1 *types.Vote, v2 *cmtproto.Vote) bool { 950 if v1 == nil || v2 == nil { 951 return false 952 } 953 954 return v1.Type == v2.Type && 955 bytes.Equal(v1.BlockID.Hash, v2.BlockID.GetHash()) && 956 v1.Height == v2.GetHeight() && 957 v1.Round == v2.Round && 958 bytes.Equal(v1.ValidatorAddress.Bytes(), v2.GetValidatorAddress()) && 959 v1.ValidatorIndex == v2.GetValidatorIndex() 960 }