go.etcd.io/etcd@v3.3.27+incompatible/contrib/raftexample/raft.go (about)

     1  // Copyright 2015 The etcd Authors
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     http://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  
    15  package main
    16  
    17  import (
    18  	"context"
    19  	"fmt"
    20  	"log"
    21  	"net/http"
    22  	"net/url"
    23  	"os"
    24  	"strconv"
    25  	"time"
    26  
    27  	"github.com/coreos/etcd/etcdserver/stats"
    28  	"github.com/coreos/etcd/pkg/fileutil"
    29  	"github.com/coreos/etcd/pkg/types"
    30  	"github.com/coreos/etcd/raft"
    31  	"github.com/coreos/etcd/raft/raftpb"
    32  	"github.com/coreos/etcd/rafthttp"
    33  	"github.com/coreos/etcd/snap"
    34  	"github.com/coreos/etcd/wal"
    35  	"github.com/coreos/etcd/wal/walpb"
    36  )
    37  
    38  // A key-value stream backed by raft
    39  type raftNode struct {
    40  	proposeC    <-chan string            // proposed messages (k,v)
    41  	confChangeC <-chan raftpb.ConfChange // proposed cluster config changes
    42  	commitC     chan<- *string           // entries committed to log (k,v)
    43  	errorC      chan<- error             // errors from raft session
    44  
    45  	id          int      // client ID for raft session
    46  	peers       []string // raft peer URLs
    47  	join        bool     // node is joining an existing cluster
    48  	waldir      string   // path to WAL directory
    49  	snapdir     string   // path to snapshot directory
    50  	getSnapshot func() ([]byte, error)
    51  	lastIndex   uint64 // index of log at start
    52  
    53  	confState     raftpb.ConfState
    54  	snapshotIndex uint64
    55  	appliedIndex  uint64
    56  
    57  	// raft backing for the commit/error channel
    58  	node        raft.Node
    59  	raftStorage *raft.MemoryStorage
    60  	wal         *wal.WAL
    61  
    62  	snapshotter      *snap.Snapshotter
    63  	snapshotterReady chan *snap.Snapshotter // signals when snapshotter is ready
    64  
    65  	snapCount uint64
    66  	transport *rafthttp.Transport
    67  	stopc     chan struct{} // signals proposal channel closed
    68  	httpstopc chan struct{} // signals http server to shutdown
    69  	httpdonec chan struct{} // signals http server shutdown complete
    70  }
    71  
    72  var defaultSnapCount uint64 = 10000
    73  
    74  // newRaftNode initiates a raft instance and returns a committed log entry
    75  // channel and error channel. Proposals for log updates are sent over the
    76  // provided the proposal channel. All log entries are replayed over the
    77  // commit channel, followed by a nil message (to indicate the channel is
    78  // current), then new log entries. To shutdown, close proposeC and read errorC.
    79  func newRaftNode(id int, peers []string, join bool, getSnapshot func() ([]byte, error), proposeC <-chan string,
    80  	confChangeC <-chan raftpb.ConfChange) (<-chan *string, <-chan error, <-chan *snap.Snapshotter) {
    81  
    82  	commitC := make(chan *string)
    83  	errorC := make(chan error)
    84  
    85  	rc := &raftNode{
    86  		proposeC:    proposeC,
    87  		confChangeC: confChangeC,
    88  		commitC:     commitC,
    89  		errorC:      errorC,
    90  		id:          id,
    91  		peers:       peers,
    92  		join:        join,
    93  		waldir:      fmt.Sprintf("raftexample-%d", id),
    94  		snapdir:     fmt.Sprintf("raftexample-%d-snap", id),
    95  		getSnapshot: getSnapshot,
    96  		snapCount:   defaultSnapCount,
    97  		stopc:       make(chan struct{}),
    98  		httpstopc:   make(chan struct{}),
    99  		httpdonec:   make(chan struct{}),
   100  
   101  		snapshotterReady: make(chan *snap.Snapshotter, 1),
   102  		// rest of structure populated after WAL replay
   103  	}
   104  	go rc.startRaft()
   105  	return commitC, errorC, rc.snapshotterReady
   106  }
   107  
   108  func (rc *raftNode) saveSnap(snap raftpb.Snapshot) error {
   109  	// must save the snapshot index to the WAL before saving the
   110  	// snapshot to maintain the invariant that we only Open the
   111  	// wal at previously-saved snapshot indexes.
   112  	walSnap := walpb.Snapshot{
   113  		Index: snap.Metadata.Index,
   114  		Term:  snap.Metadata.Term,
   115  	}
   116  	if err := rc.wal.SaveSnapshot(walSnap); err != nil {
   117  		return err
   118  	}
   119  	if err := rc.snapshotter.SaveSnap(snap); err != nil {
   120  		return err
   121  	}
   122  	return rc.wal.ReleaseLockTo(snap.Metadata.Index)
   123  }
   124  
   125  func (rc *raftNode) entriesToApply(ents []raftpb.Entry) (nents []raftpb.Entry) {
   126  	if len(ents) == 0 {
   127  		return
   128  	}
   129  	firstIdx := ents[0].Index
   130  	if firstIdx > rc.appliedIndex+1 {
   131  		log.Fatalf("first index of committed entry[%d] should <= progress.appliedIndex[%d] 1", firstIdx, rc.appliedIndex)
   132  	}
   133  	if rc.appliedIndex-firstIdx+1 < uint64(len(ents)) {
   134  		nents = ents[rc.appliedIndex-firstIdx+1:]
   135  	}
   136  	return nents
   137  }
   138  
   139  // publishEntries writes committed log entries to commit channel and returns
   140  // whether all entries could be published.
   141  func (rc *raftNode) publishEntries(ents []raftpb.Entry) bool {
   142  	for i := range ents {
   143  		switch ents[i].Type {
   144  		case raftpb.EntryNormal:
   145  			if len(ents[i].Data) == 0 {
   146  				// ignore empty messages
   147  				break
   148  			}
   149  			s := string(ents[i].Data)
   150  			select {
   151  			case rc.commitC <- &s:
   152  			case <-rc.stopc:
   153  				return false
   154  			}
   155  
   156  		case raftpb.EntryConfChange:
   157  			var cc raftpb.ConfChange
   158  			cc.Unmarshal(ents[i].Data)
   159  			rc.confState = *rc.node.ApplyConfChange(cc)
   160  			switch cc.Type {
   161  			case raftpb.ConfChangeAddNode:
   162  				if len(cc.Context) > 0 {
   163  					rc.transport.AddPeer(types.ID(cc.NodeID), []string{string(cc.Context)})
   164  				}
   165  			case raftpb.ConfChangeRemoveNode:
   166  				if cc.NodeID == uint64(rc.id) {
   167  					log.Println("I've been removed from the cluster! Shutting down.")
   168  					return false
   169  				}
   170  				rc.transport.RemovePeer(types.ID(cc.NodeID))
   171  			}
   172  		}
   173  
   174  		// after commit, update appliedIndex
   175  		rc.appliedIndex = ents[i].Index
   176  
   177  		// special nil commit to signal replay has finished
   178  		if ents[i].Index == rc.lastIndex {
   179  			select {
   180  			case rc.commitC <- nil:
   181  			case <-rc.stopc:
   182  				return false
   183  			}
   184  		}
   185  	}
   186  	return true
   187  }
   188  
   189  func (rc *raftNode) loadSnapshot() *raftpb.Snapshot {
   190  	snapshot, err := rc.snapshotter.Load()
   191  	if err != nil && err != snap.ErrNoSnapshot {
   192  		log.Fatalf("raftexample: error loading snapshot (%v)", err)
   193  	}
   194  	return snapshot
   195  }
   196  
   197  // openWAL returns a WAL ready for reading.
   198  func (rc *raftNode) openWAL(snapshot *raftpb.Snapshot) *wal.WAL {
   199  	if !wal.Exist(rc.waldir) {
   200  		if err := os.Mkdir(rc.waldir, 0750); err != nil {
   201  			log.Fatalf("raftexample: cannot create dir for wal (%v)", err)
   202  		}
   203  
   204  		w, err := wal.Create(rc.waldir, nil)
   205  		if err != nil {
   206  			log.Fatalf("raftexample: create wal error (%v)", err)
   207  		}
   208  		w.Close()
   209  	}
   210  
   211  	walsnap := walpb.Snapshot{}
   212  	if snapshot != nil {
   213  		walsnap.Index, walsnap.Term = snapshot.Metadata.Index, snapshot.Metadata.Term
   214  	}
   215  	log.Printf("loading WAL at term %d and index %d", walsnap.Term, walsnap.Index)
   216  	w, err := wal.Open(rc.waldir, walsnap)
   217  	if err != nil {
   218  		log.Fatalf("raftexample: error loading wal (%v)", err)
   219  	}
   220  
   221  	return w
   222  }
   223  
   224  // replayWAL replays WAL entries into the raft instance.
   225  func (rc *raftNode) replayWAL() *wal.WAL {
   226  	log.Printf("replaying WAL of member %d", rc.id)
   227  	snapshot := rc.loadSnapshot()
   228  	w := rc.openWAL(snapshot)
   229  	_, st, ents, err := w.ReadAll()
   230  	if err != nil {
   231  		log.Fatalf("raftexample: failed to read WAL (%v)", err)
   232  	}
   233  	rc.raftStorage = raft.NewMemoryStorage()
   234  	if snapshot != nil {
   235  		rc.raftStorage.ApplySnapshot(*snapshot)
   236  	}
   237  	rc.raftStorage.SetHardState(st)
   238  
   239  	// append to storage so raft starts at the right place in log
   240  	rc.raftStorage.Append(ents)
   241  	// send nil once lastIndex is published so client knows commit channel is current
   242  	if len(ents) > 0 {
   243  		rc.lastIndex = ents[len(ents)-1].Index
   244  	} else {
   245  		rc.commitC <- nil
   246  	}
   247  	return w
   248  }
   249  
// writeError reports err to consumers and shuts the node down. It is used,
// for example, when the transport reports an error.
//
// Shutdown order:
//  1. Stop the HTTP transport and server.
//  2. Close commitC.
//  3. Send err on errorC, which is unbuffered (created in newRaftNode), so
//     this blocks until a consumer receives it.
//  4. Close errorC.
//  5. Stop the raft node.
func (rc *raftNode) writeError(err error) {
	rc.stopHTTP()
	close(rc.commitC)
	rc.errorC <- err
	close(rc.errorC)
	rc.node.Stop()
}
   257  
   258  func (rc *raftNode) startRaft() {
   259  	if !fileutil.Exist(rc.snapdir) {
   260  		if err := os.Mkdir(rc.snapdir, 0750); err != nil {
   261  			log.Fatalf("raftexample: cannot create dir for snapshot (%v)", err)
   262  		}
   263  	}
   264  	rc.snapshotter = snap.New(rc.snapdir)
   265  	rc.snapshotterReady <- rc.snapshotter
   266  
   267  	oldwal := wal.Exist(rc.waldir)
   268  	rc.wal = rc.replayWAL()
   269  
   270  	rpeers := make([]raft.Peer, len(rc.peers))
   271  	for i := range rpeers {
   272  		rpeers[i] = raft.Peer{ID: uint64(i + 1)}
   273  	}
   274  	c := &raft.Config{
   275  		ID:              uint64(rc.id),
   276  		ElectionTick:    10,
   277  		HeartbeatTick:   1,
   278  		Storage:         rc.raftStorage,
   279  		MaxSizePerMsg:   1024 * 1024,
   280  		MaxInflightMsgs: 256,
   281  	}
   282  
   283  	if oldwal {
   284  		rc.node = raft.RestartNode(c)
   285  	} else {
   286  		startPeers := rpeers
   287  		if rc.join {
   288  			startPeers = nil
   289  		}
   290  		rc.node = raft.StartNode(c, startPeers)
   291  	}
   292  
   293  	rc.transport = &rafthttp.Transport{
   294  		ID:          types.ID(rc.id),
   295  		ClusterID:   0x1000,
   296  		Raft:        rc,
   297  		ServerStats: stats.NewServerStats("", ""),
   298  		LeaderStats: stats.NewLeaderStats(strconv.Itoa(rc.id)),
   299  		ErrorC:      make(chan error),
   300  	}
   301  
   302  	rc.transport.Start()
   303  	for i := range rc.peers {
   304  		if i+1 != rc.id {
   305  			rc.transport.AddPeer(types.ID(i+1), []string{rc.peers[i]})
   306  		}
   307  	}
   308  
   309  	go rc.serveRaft()
   310  	go rc.serveChannels()
   311  }
   312  
// stop shuts the node down without reporting an error.
//
// It stops the HTTP transport and server (waiting for the serveRaft
// goroutine to finish), closes commitC and errorC so that consumers observe
// the shutdown, and finally stops the raft node.
func (rc *raftNode) stop() {
	rc.stopHTTP()
	close(rc.commitC)
	close(rc.errorC)
	rc.node.Stop()
}
   320  
// stopHTTP stops the rafthttp transport, tells serveRaft to shut its server
// down by closing httpstopc, and then blocks until serveRaft signals
// completion by closing httpdonec.
func (rc *raftNode) stopHTTP() {
	rc.transport.Stop()
	close(rc.httpstopc)
	<-rc.httpdonec
}
   326  
   327  func (rc *raftNode) publishSnapshot(snapshotToSave raftpb.Snapshot) {
   328  	if raft.IsEmptySnap(snapshotToSave) {
   329  		return
   330  	}
   331  
   332  	log.Printf("publishing snapshot at index %d", rc.snapshotIndex)
   333  	defer log.Printf("finished publishing snapshot at index %d", rc.snapshotIndex)
   334  
   335  	if snapshotToSave.Metadata.Index <= rc.appliedIndex {
   336  		log.Fatalf("snapshot index [%d] should > progress.appliedIndex [%d] + 1", snapshotToSave.Metadata.Index, rc.appliedIndex)
   337  	}
   338  	rc.commitC <- nil // trigger kvstore to load snapshot
   339  
   340  	rc.confState = snapshotToSave.Metadata.ConfState
   341  	rc.snapshotIndex = snapshotToSave.Metadata.Index
   342  	rc.appliedIndex = snapshotToSave.Metadata.Index
   343  }
   344  
   345  var snapshotCatchUpEntriesN uint64 = 10000
   346  
   347  func (rc *raftNode) maybeTriggerSnapshot() {
   348  	if rc.appliedIndex-rc.snapshotIndex <= rc.snapCount {
   349  		return
   350  	}
   351  
   352  	log.Printf("start snapshot [applied index: %d | last snapshot index: %d]", rc.appliedIndex, rc.snapshotIndex)
   353  	data, err := rc.getSnapshot()
   354  	if err != nil {
   355  		log.Panic(err)
   356  	}
   357  	snap, err := rc.raftStorage.CreateSnapshot(rc.appliedIndex, &rc.confState, data)
   358  	if err != nil {
   359  		panic(err)
   360  	}
   361  	if err := rc.saveSnap(snap); err != nil {
   362  		panic(err)
   363  	}
   364  
   365  	compactIndex := uint64(1)
   366  	if rc.appliedIndex > snapshotCatchUpEntriesN {
   367  		compactIndex = rc.appliedIndex - snapshotCatchUpEntriesN
   368  	}
   369  	if err := rc.raftStorage.Compact(compactIndex); err != nil {
   370  		panic(err)
   371  	}
   372  
   373  	log.Printf("compacted log at index %d", compactIndex)
   374  	rc.snapshotIndex = rc.appliedIndex
   375  }
   376  
   377  func (rc *raftNode) serveChannels() {
   378  	snap, err := rc.raftStorage.Snapshot()
   379  	if err != nil {
   380  		panic(err)
   381  	}
   382  	rc.confState = snap.Metadata.ConfState
   383  	rc.snapshotIndex = snap.Metadata.Index
   384  	rc.appliedIndex = snap.Metadata.Index
   385  
   386  	defer rc.wal.Close()
   387  
   388  	ticker := time.NewTicker(100 * time.Millisecond)
   389  	defer ticker.Stop()
   390  
   391  	// send proposals over raft
   392  	go func() {
   393  		var confChangeCount uint64 = 0
   394  
   395  		for rc.proposeC != nil && rc.confChangeC != nil {
   396  			select {
   397  			case prop, ok := <-rc.proposeC:
   398  				if !ok {
   399  					rc.proposeC = nil
   400  				} else {
   401  					// blocks until accepted by raft state machine
   402  					rc.node.Propose(context.TODO(), []byte(prop))
   403  				}
   404  
   405  			case cc, ok := <-rc.confChangeC:
   406  				if !ok {
   407  					rc.confChangeC = nil
   408  				} else {
   409  					confChangeCount += 1
   410  					cc.ID = confChangeCount
   411  					rc.node.ProposeConfChange(context.TODO(), cc)
   412  				}
   413  			}
   414  		}
   415  		// client closed channel; shutdown raft if not already
   416  		close(rc.stopc)
   417  	}()
   418  
   419  	// event loop on raft state machine updates
   420  	for {
   421  		select {
   422  		case <-ticker.C:
   423  			rc.node.Tick()
   424  
   425  		// store raft entries to wal, then publish over commit channel
   426  		case rd := <-rc.node.Ready():
   427  			rc.wal.Save(rd.HardState, rd.Entries)
   428  			if !raft.IsEmptySnap(rd.Snapshot) {
   429  				rc.saveSnap(rd.Snapshot)
   430  				rc.raftStorage.ApplySnapshot(rd.Snapshot)
   431  				rc.publishSnapshot(rd.Snapshot)
   432  			}
   433  			rc.raftStorage.Append(rd.Entries)
   434  			rc.transport.Send(rd.Messages)
   435  			if ok := rc.publishEntries(rc.entriesToApply(rd.CommittedEntries)); !ok {
   436  				rc.stop()
   437  				return
   438  			}
   439  			rc.maybeTriggerSnapshot()
   440  			rc.node.Advance()
   441  
   442  		case err := <-rc.transport.ErrorC:
   443  			rc.writeError(err)
   444  			return
   445  
   446  		case <-rc.stopc:
   447  			rc.stop()
   448  			return
   449  		}
   450  	}
   451  }
   452  
   453  func (rc *raftNode) serveRaft() {
   454  	url, err := url.Parse(rc.peers[rc.id-1])
   455  	if err != nil {
   456  		log.Fatalf("raftexample: Failed parsing URL (%v)", err)
   457  	}
   458  
   459  	ln, err := newStoppableListener(url.Host, rc.httpstopc)
   460  	if err != nil {
   461  		log.Fatalf("raftexample: Failed to listen rafthttp (%v)", err)
   462  	}
   463  
   464  	err = (&http.Server{Handler: rc.transport.Handler()}).Serve(ln)
   465  	select {
   466  	case <-rc.httpstopc:
   467  	default:
   468  		log.Fatalf("raftexample: Failed to serve rafthttp (%v)", err)
   469  	}
   470  	close(rc.httpdonec)
   471  }
   472  
   473  func (rc *raftNode) Process(ctx context.Context, m raftpb.Message) error {
   474  	return rc.node.Step(ctx, m)
   475  }
   476  func (rc *raftNode) IsIDRemoved(id uint64) bool                           { return false }
   477  func (rc *raftNode) ReportUnreachable(id uint64)                          {}
   478  func (rc *raftNode) ReportSnapshot(id uint64, status raft.SnapshotStatus) {}