github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/election/streams/election_test.go (about) 1 // Copyright (c) 2021-2023, R.I. Pienaar and the Choria Project contributors 2 // 3 // SPDX-License-Identifier: Apache-2.0 4 5 package election 6 7 import ( 8 "context" 9 "fmt" 10 "os" 11 "sync" 12 "testing" 13 "time" 14 15 "github.com/nats-io/nats-server/v2/server" 16 "github.com/nats-io/nats.go" 17 . "github.com/onsi/ginkgo/v2" 18 . "github.com/onsi/gomega" 19 ) 20 21 func TestLeader(t *testing.T) { 22 RegisterFailHandler(Fail) 23 RunSpecs(t, "Providers/Election/Streams") 24 } 25 26 var _ = Describe("Choria KV Leader Election", func() { 27 var ( 28 srv *server.Server 29 nc *nats.Conn 30 js nats.KeyValueManager 31 kv nats.KeyValue 32 err error 33 debugger func(f string, a ...any) 34 ) 35 36 BeforeEach(func() { 37 skipValidate = false 38 srv, nc = startJSServer(GinkgoT()) 39 js, err = nc.JetStream() 40 Expect(err).ToNot(HaveOccurred()) 41 42 kv, err = js.CreateKeyValue(&nats.KeyValueConfig{ 43 Bucket: "LEADER_ELECTION", 44 TTL: 500 * time.Millisecond, 45 }) 46 Expect(err).ToNot(HaveOccurred()) 47 debugger = func(f string, a ...any) { 48 fmt.Fprintf(GinkgoWriter, fmt.Sprintf("%s: %s\n", time.Now(), f), a...) 49 } 50 }) 51 52 AfterEach(func() { 53 nc.Close() 54 srv.Shutdown() 55 srv.WaitForShutdown() 56 if srv.StoreDir() != "" { 57 os.RemoveAll(srv.StoreDir()) 58 } 59 }) 60 61 Describe("Election", func() { 62 It("Should validate the TTL", func() { 63 kv, err := js.CreateKeyValue(&nats.KeyValueConfig{ 64 Bucket: "LE", 65 TTL: 100 * time.Millisecond, 66 }) 67 Expect(err).ToNot(HaveOccurred()) 68 69 election, err := NewElection("test", "test.key", kv) 70 Expect(err).ToNot(HaveOccurred()) 71 err = election.Start(context.Background()) 72 Expect(err).To(MatchError("bucket TTL should be 1 second or more")) 73 74 err = js.DeleteKeyValue("LE") 75 Expect(err).ToNot(HaveOccurred()) 76 77 kv, err = js.CreateKeyValue(&nats.KeyValueConfig{ 78 Bucket: "LE", 79 TTL: 24 * time.Hour, 80 }) 81 Expect(err).ToNot(HaveOccurred()) 82 83 election, err = NewElection("test", "test.key", kv) 84 Expect(err).ToNot(HaveOccurred()) 85 err = election.Start(context.Background()) 86 Expect(err).To(MatchError("bucket TTL should be less than or equal to 1 hour")) 87 }) 88 89 It("Should allow 5 second TTLs", func() { 90 kv, err := js.CreateKeyValue(&nats.KeyValueConfig{ 91 Bucket: "LE", 92 TTL: 5 * time.Second, 93 }) 94 Expect(err).ToNot(HaveOccurred()) 95 96 _, err = NewElection("test", "test.key", kv) 97 Expect(err).ToNot(HaveOccurred()) 98 }) 99 100 It("Should correctly manage leadership", func() { 101 var ( 102 wins = 0 103 lost = 0 104 active = make(map[string]struct{}) 105 maxActive = 0 106 other = 0 107 wg = &sync.WaitGroup{} 108 mu = sync.Mutex{} 109 ) 110 111 skipValidate = true 112 113 ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) 114 defer cancel() 115 116 worker := func(wg *sync.WaitGroup, i int, key string) { 117 defer wg.Done() 118 119 name := fmt.Sprintf("member %d", i) 120 121 winCb := func() { 122 mu.Lock() 123 wins++ 124 active[name] = struct{}{} 125 act := len(active) 126 if act > maxActive { 127 maxActive = act 128 } 129 mu.Unlock() 130 131 debugger("%d became leader with %d active leaders", i, act) 132 } 133 134 lostCb := func() { 135 mu.Lock() 136 lost++ 137 delete(active, name) 138 mu.Unlock() 139 debugger("%d lost leadership", i) 140 } 141 142 elect, err := NewElection(name, key, kv, 143 OnWon(winCb), 144 OnLost(lostCb), 145 WithDebug(debugger)) 146 Expect(err).ToNot(HaveOccurred()) 147 148 err = elect.Start(ctx) 149 Expect(err).ToNot(HaveOccurred()) 150 } 151 152 for i := 0; i < 10; i++ { 153 wg.Add(1) 154 go worker(wg, i, "election") 155 } 156 157 // make sure another election is not interrupted 158 otherWorker := func(wg *sync.WaitGroup, i int) { 159 defer wg.Done() 160 161 elect, err := NewElection(fmt.Sprintf("other %d", i), "other", kv, 162 OnWon(func() { 163 mu.Lock() 164 debugger("other %d gained leader", i) 165 other++ 166 mu.Unlock() 167 }), 168 OnLost(func() { 169 defer GinkgoRecover() 170 debugger("other %d lost leader", i) 171 Fail(fmt.Sprintf("Other %d election was lost", i)) 172 })) 173 Expect(err).ToNot(HaveOccurred()) 174 175 err = elect.Start(ctx) 176 Expect(err).ToNot(HaveOccurred()) 177 } 178 wg.Add(2) 179 go otherWorker(wg, 1) 180 go otherWorker(wg, 2) 181 182 // test failure scenarios, either the key gets deleted to allow a Create() to succeed 183 // or it gets corrupted by putting a item in the key that would therefore change the sequence 184 // causing next campaign by the leader to fail. The leader will stand-down, all campaigns will 185 // fail until the corruption is removed by the MaxAge limit 186 kills := 0 187 for { 188 if ctxSleep(ctx, 400*time.Millisecond) != nil { 189 break 190 } 191 192 mu.Lock() 193 act := len(active) 194 mu.Unlock() 195 196 // only corrupt when there is a leader 197 if act == 0 { 198 continue 199 } 200 201 kills++ 202 if kills%3 == 0 { 203 debugger("deleting key") 204 Expect(kv.Delete("election")).ToNot(HaveOccurred()) 205 } else { 206 debugger("corrupting key") 207 _, err := kv.Put("election", nil) 208 Expect(err).ToNot(HaveOccurred()) 209 } 210 } 211 212 wg.Wait() 213 214 mu.Lock() 215 defer mu.Unlock() 216 217 // check we had enough keys and wins etc to have tested all scenarios 218 if kills < 4 { 219 Fail(fmt.Sprintf("had %d kills", kills)) 220 } 221 if wins < 4 { 222 Fail(fmt.Sprintf("won only %d elections for %d kills", wins, kills)) 223 } 224 if lost < 4 { 225 Fail(fmt.Sprintf("lost only %d elections", lost)) 226 } 227 if maxActive > 1 { 228 Fail(fmt.Sprintf("Had %d leaders", maxActive)) 229 } 230 }) 231 }) 232 }) 233 234 func startJSServer(t GinkgoTInterface) (*server.Server, *nats.Conn) { 235 t.Helper() 236 237 d, err := os.MkdirTemp("", "jstest") 238 if err != nil { 239 t.Fatalf("temp dir could not be made: %s", err) 240 } 241 242 opts := &server.Options{ 243 JetStream: true, 244 StoreDir: d, 245 Port: -1, 246 Host: "localhost", 247 LogFile: "/dev/stdout", 248 Trace: true, 249 } 250 251 s, err := server.NewServer(opts) 252 if err != nil { 253 t.Fatal("server start failed: ", err) 254 } 255 256 go s.Start() 257 if !s.ReadyForConnections(10 * time.Second) { 258 t.Error("nats server did not start") 259 } 260 261 nc, err := nats.Connect(s.ClientURL(), nats.UseOldRequestStyle()) 262 if err != nil { 263 t.Fatalf("client start failed: %s", err) 264 } 265 266 return s, nc 267 }