get.pme.sh/pnats@v0.0.0-20240304004023-26bb5a137ed0/server/raft_helpers_test.go (about)

     1  // Copyright 2023 The NATS Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
// Do not exclude this file with the !skip_js_tests build tag, since these
// helpers are also used by MQTT.
    16  
    17  package server
    18  
    19  import (
    20  	"encoding/binary"
    21  	"fmt"
    22  	"math/rand"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  )
    27  
    28  type stateMachine interface {
    29  	server() *Server
    30  	node() RaftNode
    31  	// This will call forward as needed so can be called on any node.
    32  	propose(data []byte)
    33  	// When entries have been committed and can be applied.
    34  	applyEntry(ce *CommittedEntry)
    35  	// When a leader change happens.
    36  	leaderChange(isLeader bool)
    37  	// Stop the raft group.
    38  	stop()
    39  	// Restart
    40  	restart()
    41  }
    42  
    43  // Factory function needed for constructor.
    44  type smFactory func(s *Server, cfg *RaftConfig, node RaftNode) stateMachine
    45  
    46  type smGroup []stateMachine
    47  
    48  // Leader of the group.
    49  func (sg smGroup) leader() stateMachine {
    50  	for _, sm := range sg {
    51  		if sm.node().Leader() {
    52  			return sm
    53  		}
    54  	}
    55  	return nil
    56  }
    57  
    58  // Wait on a leader to be elected.
    59  func (sg smGroup) waitOnLeader() {
    60  	expires := time.Now().Add(10 * time.Second)
    61  	for time.Now().Before(expires) {
    62  		for _, sm := range sg {
    63  			if sm.node().Leader() {
    64  				return
    65  			}
    66  		}
    67  		time.Sleep(100 * time.Millisecond)
    68  	}
    69  }
    70  
    71  // Pick a random member.
    72  func (sg smGroup) randomMember() stateMachine {
    73  	return sg[rand.Intn(len(sg))]
    74  }
    75  
    76  // Return a non-leader
    77  func (sg smGroup) nonLeader() stateMachine {
    78  	for _, sm := range sg {
    79  		if !sm.node().Leader() {
    80  			return sm
    81  		}
    82  	}
    83  	return nil
    84  }
    85  
    86  // Create a raft group and place on numMembers servers at random.
    87  func (c *cluster) createRaftGroup(name string, numMembers int, smf smFactory) smGroup {
    88  	c.t.Helper()
    89  	if numMembers > len(c.servers) {
    90  		c.t.Fatalf("Members > Peers: %d vs  %d", numMembers, len(c.servers))
    91  	}
    92  	servers := append([]*Server{}, c.servers...)
    93  	rand.Shuffle(len(servers), func(i, j int) { servers[i], servers[j] = servers[j], servers[i] })
    94  	return c.createRaftGroupWithPeers(name, servers[:numMembers], smf)
    95  }
    96  
    97  func (c *cluster) createRaftGroupWithPeers(name string, servers []*Server, smf smFactory) smGroup {
    98  	c.t.Helper()
    99  
   100  	var sg smGroup
   101  	var peers []string
   102  
   103  	for _, s := range servers {
   104  		// generate peer names.
   105  		s.mu.RLock()
   106  		peers = append(peers, s.sys.shash)
   107  		s.mu.RUnlock()
   108  	}
   109  
   110  	for _, s := range servers {
   111  		fs, err := newFileStore(
   112  			FileStoreConfig{StoreDir: c.t.TempDir(), BlockSize: defaultMediumBlockSize, AsyncFlush: false, SyncInterval: 5 * time.Minute},
   113  			StreamConfig{Name: name, Storage: FileStorage},
   114  		)
   115  		require_NoError(c.t, err)
   116  		cfg := &RaftConfig{Name: name, Store: c.t.TempDir(), Log: fs}
   117  		s.bootstrapRaftNode(cfg, peers, true)
   118  		n, err := s.startRaftNode(globalAccountName, cfg, pprofLabels{})
   119  		require_NoError(c.t, err)
   120  		sm := smf(s, cfg, n)
   121  		sg = append(sg, sm)
   122  		go smLoop(sm)
   123  	}
   124  	return sg
   125  }
   126  
   127  // Driver program for the state machine.
   128  // Should be run in its own go routine.
   129  func smLoop(sm stateMachine) {
   130  	s, n := sm.server(), sm.node()
   131  	qch, lch, aq := n.QuitC(), n.LeadChangeC(), n.ApplyQ()
   132  
   133  	for {
   134  		select {
   135  		case <-s.quitCh:
   136  			return
   137  		case <-qch:
   138  			return
   139  		case <-aq.ch:
   140  			ces := aq.pop()
   141  			for _, ce := range ces {
   142  				sm.applyEntry(ce)
   143  			}
   144  			aq.recycle(&ces)
   145  
   146  		case isLeader := <-lch:
   147  			sm.leaderChange(isLeader)
   148  		}
   149  	}
   150  }
   151  
   152  // Simple implementation of a replicated state.
   153  // The adder state just sums up int64 values.
   154  type stateAdder struct {
   155  	sync.Mutex
   156  	s   *Server
   157  	n   RaftNode
   158  	cfg *RaftConfig
   159  	sum int64
   160  }
   161  
   162  // Simple getters for server and the raft node.
   163  func (a *stateAdder) server() *Server {
   164  	a.Lock()
   165  	defer a.Unlock()
   166  	return a.s
   167  }
   168  func (a *stateAdder) node() RaftNode {
   169  	a.Lock()
   170  	defer a.Unlock()
   171  	return a.n
   172  }
   173  
   174  func (a *stateAdder) propose(data []byte) {
   175  	a.Lock()
   176  	defer a.Unlock()
   177  	a.n.ForwardProposal(data)
   178  }
   179  
   180  func (a *stateAdder) applyEntry(ce *CommittedEntry) {
   181  	a.Lock()
   182  	defer a.Unlock()
   183  	if ce == nil {
   184  		// This means initial state is done/replayed.
   185  		return
   186  	}
   187  	for _, e := range ce.Entries {
   188  		if e.Type == EntryNormal {
   189  			delta, _ := binary.Varint(e.Data)
   190  			a.sum += delta
   191  		} else if e.Type == EntrySnapshot {
   192  			a.sum, _ = binary.Varint(e.Data)
   193  		}
   194  	}
   195  	// Update applied.
   196  	a.n.Applied(ce.Index)
   197  }
   198  
   199  func (a *stateAdder) leaderChange(isLeader bool) {}
   200  
   201  // Adder specific to change the total.
   202  func (a *stateAdder) proposeDelta(delta int64) {
   203  	data := make([]byte, binary.MaxVarintLen64)
   204  	n := binary.PutVarint(data, int64(delta))
   205  	a.propose(data[:n])
   206  }
   207  
   208  // Stop the group.
   209  func (a *stateAdder) stop() {
   210  	a.Lock()
   211  	defer a.Unlock()
   212  	a.n.Stop()
   213  }
   214  
   215  // Restart the group
   216  func (a *stateAdder) restart() {
   217  	a.Lock()
   218  	defer a.Unlock()
   219  
   220  	if a.n.State() != Closed {
   221  		return
   222  	}
   223  
   224  	// The filestore is stopped as well, so need to extract the parts to recreate it.
   225  	rn := a.n.(*raft)
   226  	fs := rn.wal.(*fileStore)
   227  
   228  	var err error
   229  	a.cfg.Log, err = newFileStore(fs.fcfg, fs.cfg.StreamConfig)
   230  	if err != nil {
   231  		panic(err)
   232  	}
   233  	a.n, err = a.s.startRaftNode(globalAccountName, a.cfg, pprofLabels{})
   234  	if err != nil {
   235  		panic(err)
   236  	}
   237  	// Finally restart the driver.
   238  	go smLoop(a)
   239  }
   240  
   241  // Total for the adder state machine.
   242  func (a *stateAdder) total() int64 {
   243  	a.Lock()
   244  	defer a.Unlock()
   245  	return a.sum
   246  }
   247  
   248  // Install a snapshot.
   249  func (a *stateAdder) snapshot(t *testing.T) {
   250  	a.Lock()
   251  	defer a.Unlock()
   252  	data := make([]byte, binary.MaxVarintLen64)
   253  	n := binary.PutVarint(data, a.sum)
   254  	snap := data[:n]
   255  	require_NoError(t, a.n.InstallSnapshot(snap))
   256  }
   257  
   258  // Helper to wait for a certain state.
   259  func (rg smGroup) waitOnTotal(t *testing.T, expected int64) {
   260  	t.Helper()
   261  	checkFor(t, 20*time.Second, 200*time.Millisecond, func() error {
   262  		for _, sm := range rg {
   263  			asm := sm.(*stateAdder)
   264  			if total := asm.total(); total != expected {
   265  				return fmt.Errorf("Adder on %v has wrong total: %d vs %d",
   266  					asm.server(), total, expected)
   267  			}
   268  		}
   269  		return nil
   270  	})
   271  }
   272  
   273  // Factory function.
   274  func newStateAdder(s *Server, cfg *RaftConfig, n RaftNode) stateMachine {
   275  	return &stateAdder{s: s, n: n, cfg: cfg}
   276  }