github.com/bigzoro/my_simplechain@v0.0.0-20240315012955-8ad0a2a29bb9/consensus/raft/backend/snapshot.go (about)

     1  package backend
     2  
     3  import (
     4  	"fmt"
     5  	"io"
     6  	"math/big"
     7  	"sort"
     8  	"time"
     9  
    10  	"github.com/bigzoro/my_simplechain/common"
    11  	"github.com/bigzoro/my_simplechain/consensus/raft"
    12  	"github.com/bigzoro/my_simplechain/core/types"
    13  	"github.com/bigzoro/my_simplechain/eth/downloader"
    14  	"github.com/bigzoro/my_simplechain/log"
    15  	"github.com/bigzoro/my_simplechain/rlp"
    16  
    17  	"github.com/coreos/etcd/raft/raftpb"
    18  	"github.com/coreos/etcd/snap"
    19  	"github.com/coreos/etcd/wal/walpb"
    20  	mapset "github.com/deckarep/golang-set"
    21  )
    22  
    23  // Snapshot
    24  
    25  type Snapshot struct {
    26  	addresses      []raft.Address
    27  	removedRaftIds []uint16 // Raft IDs for permanently removed peers
    28  	headBlockHash  common.Hash
    29  }
    30  
    31  type ByRaftId []raft.Address
    32  
    33  func (a ByRaftId) Len() int           { return len(a) }
    34  func (a ByRaftId) Swap(i, j int)      { a[i], a[j] = a[j], a[i] }
    35  func (a ByRaftId) Less(i, j int) bool { return a[i].RaftId < a[j].RaftId }
    36  
    37  func (pm *ProtocolManager) buildSnapshot() *Snapshot {
    38  	pm.mu.RLock()
    39  	defer pm.mu.RUnlock()
    40  
    41  	numNodes := len(pm.confState.Nodes)
    42  	numRemovedNodes := pm.removedPeers.Cardinality()
    43  
    44  	snapshot := &Snapshot{
    45  		addresses:      make([]raft.Address, numNodes),
    46  		removedRaftIds: make([]uint16, numRemovedNodes),
    47  		headBlockHash:  pm.blockchain.CurrentBlock().Hash(),
    48  	}
    49  
    50  	// Populate addresses
    51  
    52  	for i, rawRaftId := range pm.confState.Nodes {
    53  		raftId := uint16(rawRaftId)
    54  
    55  		if raftId == pm.raftId {
    56  			snapshot.addresses[i] = *pm.address
    57  		} else {
    58  			snapshot.addresses[i] = *pm.peers[raftId].Address
    59  		}
    60  	}
    61  	sort.Sort(ByRaftId(snapshot.addresses))
    62  
    63  	// Populate removed IDs
    64  	i := 0
    65  	for removedIface := range pm.removedPeers.Iterator().C {
    66  		snapshot.removedRaftIds[i] = removedIface.(uint16)
    67  		i++
    68  	}
    69  
    70  	return snapshot
    71  }
    72  
    73  // Note that we do *not* read `pm.appliedIndex` here. We only use the `index`
    74  // parameter instead. This is because we need to support a scenario when we
    75  // snapshot for a future index that we have not yet recorded in LevelDB. See
    76  // comments around the use of `forceSnapshot`.
    77  func (pm *ProtocolManager) triggerSnapshot(index uint64) {
    78  	pm.mu.RLock()
    79  	snapshotIndex := pm.snapshotIndex
    80  	pm.mu.RUnlock()
    81  
    82  	log.Info("start snapshot", "applied index", pm.appliedIndex, "last snapshot index", snapshotIndex)
    83  
    84  	//snapData := pm.blockchain.CurrentBlock().Hash().Bytes()
    85  	//snap, err := pm.raftStorage.CreateSnapshot(pm.appliedIndex, &pm.confState, snapData)
    86  	snapData := pm.buildSnapshot().toBytes()
    87  	snap, err := pm.raftStorage.CreateSnapshot(index, &pm.confState, snapData)
    88  	if err != nil {
    89  		panic(err)
    90  	}
    91  	if err := pm.saveRaftSnapshot(snap); err != nil {
    92  		panic(err)
    93  	}
    94  	// Discard all log entries prior to index.
    95  	if err := pm.raftStorage.Compact(index); err != nil {
    96  		panic(err)
    97  	}
    98  	log.Info("compacted log", "index", pm.appliedIndex)
    99  
   100  	pm.mu.Lock()
   101  	pm.snapshotIndex = index
   102  	pm.mu.Unlock()
   103  }
   104  
   105  func confStateIdSet(confState raftpb.ConfState) mapset.Set {
   106  	set := mapset.NewSet()
   107  	for _, rawRaftId := range confState.Nodes {
   108  		set.Add(uint16(rawRaftId))
   109  	}
   110  	return set
   111  }
   112  
   113  func (pm *ProtocolManager) updateClusterMembership(newConfState raftpb.ConfState, addresses []raft.Address, removedRaftIds []uint16) {
   114  	log.Info("updating cluster membership per raft snapshot")
   115  
   116  	prevConfState := pm.confState
   117  
   118  	// Update tombstones for permanently removed peers. For simplicity we do not
   119  	// allow the re-use of peer IDs once a peer is removed.
   120  
   121  	removedPeers := mapset.NewSet()
   122  	for _, removedRaftId := range removedRaftIds {
   123  		removedPeers.Add(removedRaftId)
   124  	}
   125  	pm.mu.Lock()
   126  	pm.removedPeers = removedPeers
   127  	pm.mu.Unlock()
   128  
   129  	// Remove old peers that we're still connected to
   130  
   131  	prevIds := confStateIdSet(prevConfState)
   132  	newIds := confStateIdSet(newConfState)
   133  	idsToRemove := prevIds.Difference(newIds)
   134  	for idIfaceToRemove := range idsToRemove.Iterator().C {
   135  		raftId := idIfaceToRemove.(uint16)
   136  		log.Info("removing old raft peer", "peer id", raftId)
   137  
   138  		pm.removePeer(raftId)
   139  	}
   140  
   141  	// Update local and remote addresses
   142  
   143  	for _, tempAddress := range addresses {
   144  		address := tempAddress // Allocate separately on the heap for each iteration.
   145  
   146  		if address.RaftId == pm.raftId {
   147  			// If we're a newcomer to an existing cluster, this is where we learn
   148  			// our own Address.
   149  			pm.setLocalAddress(&address)
   150  		} else {
   151  			pm.mu.RLock()
   152  			existingPeer := pm.peers[address.RaftId]
   153  			pm.mu.RUnlock()
   154  
   155  			if existingPeer == nil {
   156  				log.Info("adding new raft peer", "raft id", address.RaftId)
   157  				pm.addPeer(&address)
   158  			}
   159  		}
   160  	}
   161  
   162  	pm.mu.Lock()
   163  	pm.confState = newConfState
   164  	pm.mu.Unlock()
   165  
   166  	log.Info("updated cluster membership")
   167  }
   168  
   169  func (pm *ProtocolManager) maybeTriggerSnapshot() {
   170  	pm.mu.RLock()
   171  	appliedIndex := pm.appliedIndex
   172  	entriesSinceLastSnap := appliedIndex - pm.snapshotIndex
   173  	pm.mu.RUnlock()
   174  
   175  	if entriesSinceLastSnap < raft.SnapshotPeriod {
   176  		return
   177  	}
   178  
   179  	pm.triggerSnapshot(appliedIndex)
   180  }
   181  
   182  func (pm *ProtocolManager) loadSnapshot() *raftpb.Snapshot {
   183  	if raftSnapshot := pm.readRaftSnapshot(); raftSnapshot != nil {
   184  		log.Info("loading snapshot")
   185  
   186  		pm.applyRaftSnapshot(*raftSnapshot)
   187  
   188  		return raftSnapshot
   189  	} else {
   190  		log.Info("no snapshot to load")
   191  
   192  		return nil
   193  	}
   194  }
   195  
   196  func (snapshot *Snapshot) toBytes() []byte {
   197  	size, r, err := rlp.EncodeToReader(snapshot)
   198  	if err != nil {
   199  		panic(fmt.Sprintf("error: failed to RLP-encode Snapshot: %s", err.Error()))
   200  	}
   201  	var buffer = make([]byte, uint32(size))
   202  	r.Read(buffer)
   203  
   204  	return buffer
   205  }
   206  
   207  func bytesToSnapshot(bytes []byte) *Snapshot {
   208  	var snapshot Snapshot
   209  	if err := rlp.DecodeBytes(bytes, &snapshot); err != nil {
   210  		raft.Fatalf("failed to RLP-decode Snapshot: %v", err)
   211  	}
   212  	return &snapshot
   213  }
   214  
   215  func (snapshot *Snapshot) EncodeRLP(w io.Writer) error {
   216  	return rlp.Encode(w, []interface{}{snapshot.addresses, snapshot.removedRaftIds, snapshot.headBlockHash})
   217  }
   218  
   219  func (snapshot *Snapshot) DecodeRLP(s *rlp.Stream) error {
   220  	// These fields need to be public:
   221  	var temp struct {
   222  		Addresses      []raft.Address
   223  		RemovedRaftIds []uint16
   224  		HeadBlockHash  common.Hash
   225  	}
   226  
   227  	if err := s.Decode(&temp); err != nil {
   228  		return err
   229  	} else {
   230  		snapshot.addresses, snapshot.removedRaftIds, snapshot.headBlockHash = temp.Addresses, temp.RemovedRaftIds, temp.HeadBlockHash
   231  		return nil
   232  	}
   233  }
   234  
   235  // Raft snapshot
   236  
   237  func (pm *ProtocolManager) saveRaftSnapshot(snap raftpb.Snapshot) error {
   238  	if err := pm.snapshotter.SaveSnap(snap); err != nil {
   239  		return err
   240  	}
   241  
   242  	walSnap := walpb.Snapshot{
   243  		Index: snap.Metadata.Index,
   244  		Term:  snap.Metadata.Term,
   245  	}
   246  
   247  	if err := pm.wal.SaveSnapshot(walSnap); err != nil {
   248  		return err
   249  	}
   250  
   251  	return pm.wal.ReleaseLockTo(snap.Metadata.Index)
   252  }
   253  
   254  func (pm *ProtocolManager) readRaftSnapshot() *raftpb.Snapshot {
   255  	snapshot, err := pm.snapshotter.Load()
   256  	if err != nil && err != snap.ErrNoSnapshot {
   257  		raft.Fatalf("error loading snapshot: %v", err)
   258  	}
   259  
   260  	return snapshot
   261  }
   262  
   263  func (pm *ProtocolManager) applyRaftSnapshot(raftSnapshot raftpb.Snapshot) {
   264  	log.Info("applying snapshot to raft storage")
   265  	if err := pm.raftStorage.ApplySnapshot(raftSnapshot); err != nil {
   266  		raft.Fatalf("failed to apply snapshot: %s", err)
   267  	}
   268  	snapshot := bytesToSnapshot(raftSnapshot.Data)
   269  
   270  	latestBlockHash := snapshot.headBlockHash
   271  
   272  	pm.updateClusterMembership(raftSnapshot.Metadata.ConfState, snapshot.addresses, snapshot.removedRaftIds)
   273  
   274  	preSyncHead := pm.blockchain.CurrentBlock()
   275  
   276  	if latestBlock := pm.blockchain.GetBlockByHash(latestBlockHash); latestBlock == nil {
   277  		pm.syncBlockchainUntil(latestBlockHash)
   278  		pm.logNewlyAcceptedTransactions(preSyncHead)
   279  
   280  		log.Info("Successfully extended chain", "hash", pm.blockchain.CurrentBlock().Hash())
   281  	} else {
   282  		// added for permissions changes to indicate node sync up has started
   283  		//TODO types.SetSyncStatus()
   284  		log.Info("blockchain is caught up; no need to synchronize")
   285  	}
   286  
   287  	snapMeta := raftSnapshot.Metadata
   288  	pm.confState = snapMeta.ConfState
   289  	pm.mu.Lock()
   290  	pm.snapshotIndex = snapMeta.Index
   291  	pm.mu.Unlock()
   292  }
   293  
   294  func (pm *ProtocolManager) syncBlockchainUntil(hash common.Hash) {
   295  	pm.mu.RLock()
   296  	peerMap := make(map[uint16]*raft.Peer, len(pm.peers))
   297  	for raftId, peer := range pm.peers {
   298  		peerMap[raftId] = peer
   299  	}
   300  	pm.mu.RUnlock()
   301  
   302  	for {
   303  		for peerId, peer := range peerMap {
   304  			log.Info("synchronizing with peer", "peer id", peerId, "hash", hash)
   305  
   306  			peerId := peer.P2pNode.ID().String()
   307  			peerIdPrefix := fmt.Sprintf("%x", peer.P2pNode.ID().Bytes()[:8])
   308  
   309  			if err := pm.downloader.Synchronise(peerIdPrefix, hash, big.NewInt(0), downloader.FullSync); err != nil {
   310  				log.Info("failed to synchronize with peer", "peer id", peerId)
   311  
   312  				time.Sleep(500 * time.Millisecond)
   313  			} else {
   314  				return
   315  			}
   316  		}
   317  	}
   318  }
   319  
   320  func (pm *ProtocolManager) logNewlyAcceptedTransactions(preSyncHead *types.Block) {
   321  	newHead := pm.blockchain.CurrentBlock()
   322  	numBlocks := newHead.NumberU64() - preSyncHead.NumberU64()
   323  	blocks := make([]*types.Block, numBlocks)
   324  	currBlock := newHead
   325  	blocksSeen := 0
   326  	for currBlock.Hash() != preSyncHead.Hash() {
   327  		blocks[int(numBlocks)-(1+blocksSeen)] = currBlock
   328  
   329  		blocksSeen += 1
   330  		currBlock = pm.blockchain.GetBlockByHash(currBlock.ParentHash())
   331  	}
   332  }