github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/election/streams/election_test.go (about)

     1  // Copyright (c) 2021-2023, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package election
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"os"
    11  	"sync"
    12  	"testing"
    13  	"time"
    14  
    15  	"github.com/nats-io/nats-server/v2/server"
    16  	"github.com/nats-io/nats.go"
    17  	. "github.com/onsi/ginkgo/v2"
    18  	. "github.com/onsi/gomega"
    19  )
    20  
    21  func TestLeader(t *testing.T) {
    22  	RegisterFailHandler(Fail)
    23  	RunSpecs(t, "Providers/Election/Streams")
    24  }
    25  
    26  var _ = Describe("Choria KV Leader Election", func() {
    27  	var (
    28  		srv      *server.Server
    29  		nc       *nats.Conn
    30  		js       nats.KeyValueManager
    31  		kv       nats.KeyValue
    32  		err      error
    33  		debugger func(f string, a ...any)
    34  	)
    35  
    36  	BeforeEach(func() {
    37  		skipValidate = false
    38  		srv, nc = startJSServer(GinkgoT())
    39  		js, err = nc.JetStream()
    40  		Expect(err).ToNot(HaveOccurred())
    41  
    42  		kv, err = js.CreateKeyValue(&nats.KeyValueConfig{
    43  			Bucket: "LEADER_ELECTION",
    44  			TTL:    500 * time.Millisecond,
    45  		})
    46  		Expect(err).ToNot(HaveOccurred())
    47  		debugger = func(f string, a ...any) {
    48  			fmt.Fprintf(GinkgoWriter, fmt.Sprintf("%s: %s\n", time.Now(), f), a...)
    49  		}
    50  	})
    51  
    52  	AfterEach(func() {
    53  		nc.Close()
    54  		srv.Shutdown()
    55  		srv.WaitForShutdown()
    56  		if srv.StoreDir() != "" {
    57  			os.RemoveAll(srv.StoreDir())
    58  		}
    59  	})
    60  
    61  	Describe("Election", func() {
    62  		It("Should validate the TTL", func() {
    63  			kv, err := js.CreateKeyValue(&nats.KeyValueConfig{
    64  				Bucket: "LE",
    65  				TTL:    100 * time.Millisecond,
    66  			})
    67  			Expect(err).ToNot(HaveOccurred())
    68  
    69  			election, err := NewElection("test", "test.key", kv)
    70  			Expect(err).ToNot(HaveOccurred())
    71  			err = election.Start(context.Background())
    72  			Expect(err).To(MatchError("bucket TTL should be 1 second or more"))
    73  
    74  			err = js.DeleteKeyValue("LE")
    75  			Expect(err).ToNot(HaveOccurred())
    76  
    77  			kv, err = js.CreateKeyValue(&nats.KeyValueConfig{
    78  				Bucket: "LE",
    79  				TTL:    24 * time.Hour,
    80  			})
    81  			Expect(err).ToNot(HaveOccurred())
    82  
    83  			election, err = NewElection("test", "test.key", kv)
    84  			Expect(err).ToNot(HaveOccurred())
    85  			err = election.Start(context.Background())
    86  			Expect(err).To(MatchError("bucket TTL should be less than or equal to 1 hour"))
    87  		})
    88  
    89  		It("Should allow 5 second TTLs", func() {
    90  			kv, err := js.CreateKeyValue(&nats.KeyValueConfig{
    91  				Bucket: "LE",
    92  				TTL:    5 * time.Second,
    93  			})
    94  			Expect(err).ToNot(HaveOccurred())
    95  
    96  			_, err = NewElection("test", "test.key", kv)
    97  			Expect(err).ToNot(HaveOccurred())
    98  		})
    99  
   100  		It("Should correctly manage leadership", func() {
   101  			var (
   102  				wins      = 0
   103  				lost      = 0
   104  				active    = make(map[string]struct{})
   105  				maxActive = 0
   106  				other     = 0
   107  				wg        = &sync.WaitGroup{}
   108  				mu        = sync.Mutex{}
   109  			)
   110  
   111  			skipValidate = true
   112  
   113  			ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
   114  			defer cancel()
   115  
   116  			worker := func(wg *sync.WaitGroup, i int, key string) {
   117  				defer wg.Done()
   118  
   119  				name := fmt.Sprintf("member %d", i)
   120  
   121  				winCb := func() {
   122  					mu.Lock()
   123  					wins++
   124  					active[name] = struct{}{}
   125  					act := len(active)
   126  					if act > maxActive {
   127  						maxActive = act
   128  					}
   129  					mu.Unlock()
   130  
   131  					debugger("%d became leader with %d active leaders", i, act)
   132  				}
   133  
   134  				lostCb := func() {
   135  					mu.Lock()
   136  					lost++
   137  					delete(active, name)
   138  					mu.Unlock()
   139  					debugger("%d lost leadership", i)
   140  				}
   141  
   142  				elect, err := NewElection(name, key, kv,
   143  					OnWon(winCb),
   144  					OnLost(lostCb),
   145  					WithDebug(debugger))
   146  				Expect(err).ToNot(HaveOccurred())
   147  
   148  				err = elect.Start(ctx)
   149  				Expect(err).ToNot(HaveOccurred())
   150  			}
   151  
   152  			for i := 0; i < 10; i++ {
   153  				wg.Add(1)
   154  				go worker(wg, i, "election")
   155  			}
   156  
   157  			// make sure another election is not interrupted
   158  			otherWorker := func(wg *sync.WaitGroup, i int) {
   159  				defer wg.Done()
   160  
   161  				elect, err := NewElection(fmt.Sprintf("other %d", i), "other", kv,
   162  					OnWon(func() {
   163  						mu.Lock()
   164  						debugger("other %d gained leader", i)
   165  						other++
   166  						mu.Unlock()
   167  					}),
   168  					OnLost(func() {
   169  						defer GinkgoRecover()
   170  						debugger("other %d lost leader", i)
   171  						Fail(fmt.Sprintf("Other %d election was lost", i))
   172  					}))
   173  				Expect(err).ToNot(HaveOccurred())
   174  
   175  				err = elect.Start(ctx)
   176  				Expect(err).ToNot(HaveOccurred())
   177  			}
   178  			wg.Add(2)
   179  			go otherWorker(wg, 1)
   180  			go otherWorker(wg, 2)
   181  
   182  			// test failure scenarios, either the key gets deleted to allow a Create() to succeed
   183  			// or it gets corrupted by putting a item in the key that would therefore change the sequence
   184  			// causing next campaign by the leader to fail. The leader will stand-down, all campaigns will
   185  			// fail until the corruption is removed by the MaxAge limit
   186  			kills := 0
   187  			for {
   188  				if ctxSleep(ctx, 400*time.Millisecond) != nil {
   189  					break
   190  				}
   191  
   192  				mu.Lock()
   193  				act := len(active)
   194  				mu.Unlock()
   195  
   196  				// only corrupt when there is a leader
   197  				if act == 0 {
   198  					continue
   199  				}
   200  
   201  				kills++
   202  				if kills%3 == 0 {
   203  					debugger("deleting key")
   204  					Expect(kv.Delete("election")).ToNot(HaveOccurred())
   205  				} else {
   206  					debugger("corrupting key")
   207  					_, err := kv.Put("election", nil)
   208  					Expect(err).ToNot(HaveOccurred())
   209  				}
   210  			}
   211  
   212  			wg.Wait()
   213  
   214  			mu.Lock()
   215  			defer mu.Unlock()
   216  
   217  			// check we had enough keys and wins etc to have tested all scenarios
   218  			if kills < 4 {
   219  				Fail(fmt.Sprintf("had %d kills", kills))
   220  			}
   221  			if wins < 4 {
   222  				Fail(fmt.Sprintf("won only %d elections for %d kills", wins, kills))
   223  			}
   224  			if lost < 4 {
   225  				Fail(fmt.Sprintf("lost only %d elections", lost))
   226  			}
   227  			if maxActive > 1 {
   228  				Fail(fmt.Sprintf("Had %d leaders", maxActive))
   229  			}
   230  		})
   231  	})
   232  })
   233  
   234  func startJSServer(t GinkgoTInterface) (*server.Server, *nats.Conn) {
   235  	t.Helper()
   236  
   237  	d, err := os.MkdirTemp("", "jstest")
   238  	if err != nil {
   239  		t.Fatalf("temp dir could not be made: %s", err)
   240  	}
   241  
   242  	opts := &server.Options{
   243  		JetStream: true,
   244  		StoreDir:  d,
   245  		Port:      -1,
   246  		Host:      "localhost",
   247  		LogFile:   "/dev/stdout",
   248  		Trace:     true,
   249  	}
   250  
   251  	s, err := server.NewServer(opts)
   252  	if err != nil {
   253  		t.Fatal("server start failed: ", err)
   254  	}
   255  
   256  	go s.Start()
   257  	if !s.ReadyForConnections(10 * time.Second) {
   258  		t.Error("nats server did not start")
   259  	}
   260  
   261  	nc, err := nats.Connect(s.ClientURL(), nats.UseOldRequestStyle())
   262  	if err != nil {
   263  		t.Fatalf("client start failed: %s", err)
   264  	}
   265  
   266  	return s, nc
   267  }