github.com/nats-io/nats-server/v2@v2.11.0-preview.2/server/raft_helpers_test.go (about)

     1  // Copyright 2023 The NATS Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
// Do not exclude this file with the !skip_js_tests build tag since these helpers
// are also used by MQTT.
    16  
    17  package server
    18  
    19  import (
    20  	"encoding/binary"
    21  	"fmt"
    22  	"math/rand"
    23  	"sync"
    24  	"testing"
    25  	"time"
    26  )
    27  
    28  type stateMachine interface {
    29  	server() *Server
    30  	node() RaftNode
    31  	// This will call forward as needed so can be called on any node.
    32  	propose(data []byte)
    33  	// When entries have been committed and can be applied.
    34  	applyEntry(ce *CommittedEntry)
    35  	// When a leader change happens.
    36  	leaderChange(isLeader bool)
    37  	// Stop the raft group.
    38  	stop()
    39  	// Restart
    40  	restart()
    41  }
    42  
    43  // Factory function needed for constructor.
    44  type smFactory func(s *Server, cfg *RaftConfig, node RaftNode) stateMachine
    45  
    46  type smGroup []stateMachine
    47  
    48  // Leader of the group.
    49  func (sg smGroup) leader() stateMachine {
    50  	for _, sm := range sg {
    51  		if sm.node().Leader() {
    52  			return sm
    53  		}
    54  	}
    55  	return nil
    56  }
    57  
    58  // Wait on a leader to be elected.
    59  func (sg smGroup) waitOnLeader() {
    60  	expires := time.Now().Add(10 * time.Second)
    61  	for time.Now().Before(expires) {
    62  		for _, sm := range sg {
    63  			if sm.node().Leader() {
    64  				return
    65  			}
    66  		}
    67  		time.Sleep(100 * time.Millisecond)
    68  	}
    69  }
    70  
    71  // Pick a random member.
    72  func (sg smGroup) randomMember() stateMachine {
    73  	return sg[rand.Intn(len(sg))]
    74  }
    75  
    76  // Return a non-leader
    77  func (sg smGroup) nonLeader() stateMachine {
    78  	for _, sm := range sg {
    79  		if !sm.node().Leader() {
    80  			return sm
    81  		}
    82  	}
    83  	return nil
    84  }
    85  
    86  // Create a raft group and place on numMembers servers at random.
    87  func (c *cluster) createRaftGroup(name string, numMembers int, smf smFactory) smGroup {
    88  	c.t.Helper()
    89  	if numMembers > len(c.servers) {
    90  		c.t.Fatalf("Members > Peers: %d vs  %d", numMembers, len(c.servers))
    91  	}
    92  	servers := append([]*Server{}, c.servers...)
    93  	rand.Shuffle(len(servers), func(i, j int) { servers[i], servers[j] = servers[j], servers[i] })
    94  	return c.createRaftGroupWithPeers(name, servers[:numMembers], smf)
    95  }
    96  
    97  func (c *cluster) createRaftGroupWithPeers(name string, servers []*Server, smf smFactory) smGroup {
    98  	c.t.Helper()
    99  
   100  	var sg smGroup
   101  	var peers []string
   102  
   103  	for _, s := range servers {
   104  		// generate peer names.
   105  		s.mu.RLock()
   106  		peers = append(peers, s.sys.shash)
   107  		s.mu.RUnlock()
   108  	}
   109  
   110  	for _, s := range servers {
   111  		fs, err := newFileStore(
   112  			FileStoreConfig{StoreDir: c.t.TempDir(), BlockSize: defaultMediumBlockSize, AsyncFlush: false, SyncInterval: 5 * time.Minute},
   113  			StreamConfig{Name: name, Storage: FileStorage},
   114  		)
   115  		require_NoError(c.t, err)
   116  		cfg := &RaftConfig{Name: name, Store: c.t.TempDir(), Log: fs}
   117  		s.bootstrapRaftNode(cfg, peers, true)
   118  		n, err := s.startRaftNode(globalAccountName, cfg, pprofLabels{})
   119  		require_NoError(c.t, err)
   120  		sm := smf(s, cfg, n)
   121  		sg = append(sg, sm)
   122  		go smLoop(sm)
   123  	}
   124  	return sg
   125  }
   126  
   127  // Driver program for the state machine.
   128  // Should be run in its own go routine.
   129  func smLoop(sm stateMachine) {
   130  	s, n := sm.server(), sm.node()
   131  	qch, lch, aq := n.QuitC(), n.LeadChangeC(), n.ApplyQ()
   132  
   133  	for {
   134  		select {
   135  		case <-s.quitCh:
   136  			return
   137  		case <-qch:
   138  			return
   139  		case <-aq.ch:
   140  			ces := aq.pop()
   141  			for _, ce := range ces {
   142  				sm.applyEntry(ce)
   143  			}
   144  			aq.recycle(&ces)
   145  
   146  		case isLeader := <-lch:
   147  			sm.leaderChange(isLeader)
   148  		}
   149  	}
   150  }
   151  
   152  // Simple implementation of a replicated state.
   153  // The adder state just sums up int64 values.
   154  type stateAdder struct {
   155  	sync.Mutex
   156  	s   *Server
   157  	n   RaftNode
   158  	cfg *RaftConfig
   159  	sum int64
   160  	lch chan bool
   161  }
   162  
   163  // Simple getters for server and the raft node.
   164  func (a *stateAdder) server() *Server {
   165  	a.Lock()
   166  	defer a.Unlock()
   167  	return a.s
   168  }
   169  func (a *stateAdder) node() RaftNode {
   170  	a.Lock()
   171  	defer a.Unlock()
   172  	return a.n
   173  }
   174  
   175  func (a *stateAdder) propose(data []byte) {
   176  	a.Lock()
   177  	defer a.Unlock()
   178  	a.n.ForwardProposal(data)
   179  }
   180  
   181  func (a *stateAdder) applyEntry(ce *CommittedEntry) {
   182  	a.Lock()
   183  	defer a.Unlock()
   184  	if ce == nil {
   185  		// This means initial state is done/replayed.
   186  		return
   187  	}
   188  	for _, e := range ce.Entries {
   189  		if e.Type == EntryNormal {
   190  			delta, _ := binary.Varint(e.Data)
   191  			a.sum += delta
   192  		} else if e.Type == EntrySnapshot {
   193  			a.sum, _ = binary.Varint(e.Data)
   194  		}
   195  	}
   196  	// Update applied.
   197  	a.n.Applied(ce.Index)
   198  }
   199  
   200  func (a *stateAdder) leaderChange(isLeader bool) {
   201  	select {
   202  	case a.lch <- isLeader:
   203  	default:
   204  	}
   205  }
   206  
   207  // Adder specific to change the total.
   208  func (a *stateAdder) proposeDelta(delta int64) {
   209  	data := make([]byte, binary.MaxVarintLen64)
   210  	n := binary.PutVarint(data, int64(delta))
   211  	a.propose(data[:n])
   212  }
   213  
   214  // Stop the group.
   215  func (a *stateAdder) stop() {
   216  	a.Lock()
   217  	defer a.Unlock()
   218  	a.n.Stop()
   219  }
   220  
   221  // Restart the group
   222  func (a *stateAdder) restart() {
   223  	a.Lock()
   224  	defer a.Unlock()
   225  
   226  	if a.n.State() != Closed {
   227  		return
   228  	}
   229  
   230  	// The filestore is stopped as well, so need to extract the parts to recreate it.
   231  	rn := a.n.(*raft)
   232  	fs := rn.wal.(*fileStore)
   233  
   234  	var err error
   235  	a.cfg.Log, err = newFileStore(fs.fcfg, fs.cfg.StreamConfig)
   236  	if err != nil {
   237  		panic(err)
   238  	}
   239  	a.n, err = a.s.startRaftNode(globalAccountName, a.cfg, pprofLabels{})
   240  	if err != nil {
   241  		panic(err)
   242  	}
   243  	// Finally restart the driver.
   244  	go smLoop(a)
   245  }
   246  
   247  // Total for the adder state machine.
   248  func (a *stateAdder) total() int64 {
   249  	a.Lock()
   250  	defer a.Unlock()
   251  	return a.sum
   252  }
   253  
   254  // Install a snapshot.
   255  func (a *stateAdder) snapshot(t *testing.T) {
   256  	a.Lock()
   257  	defer a.Unlock()
   258  	data := make([]byte, binary.MaxVarintLen64)
   259  	n := binary.PutVarint(data, a.sum)
   260  	snap := data[:n]
   261  	require_NoError(t, a.n.InstallSnapshot(snap))
   262  }
   263  
   264  // Helper to wait for a certain state.
   265  func (rg smGroup) waitOnTotal(t *testing.T, expected int64) {
   266  	t.Helper()
   267  	checkFor(t, 20*time.Second, 200*time.Millisecond, func() error {
   268  		for _, sm := range rg {
   269  			asm := sm.(*stateAdder)
   270  			if total := asm.total(); total != expected {
   271  				return fmt.Errorf("Adder on %v has wrong total: %d vs %d",
   272  					asm.server(), total, expected)
   273  			}
   274  		}
   275  		return nil
   276  	})
   277  }
   278  
   279  // Factory function.
   280  func newStateAdder(s *Server, cfg *RaftConfig, n RaftNode) stateMachine {
   281  	return &stateAdder{s: s, n: n, cfg: cfg, lch: make(chan bool, 1)}
   282  }