github.com/choria-io/go-choria@v0.28.1-0.20240416190746-b3bf9c7d5a45/providers/election/streams/election.go (about)

     1  // Copyright (c) 2021-2023, R.I. Pienaar and the Choria Project contributors
     2  //
     3  // SPDX-License-Identifier: Apache-2.0
     4  
     5  package election
     6  
     7  import (
     8  	"context"
     9  	"fmt"
    10  	"math"
    11  	"math/rand"
    12  	"sync"
    13  	"time"
    14  
    15  	"github.com/choria-io/go-choria/backoff"
    16  	"github.com/nats-io/nats.go"
    17  )
    18  
    19  // Backoff controls the interval of campaigns
    20  type Backoff interface {
    21  	// Duration returns the time to sleep for the nth invocation
    22  	Duration(n int) time.Duration
    23  }
    24  
    25  // State indicates the current state of the election
    26  type State uint
    27  
    28  const (
    29  	// UnknownState indicates the state is unknown, like when the election is not started
    30  	UnknownState State = 0
    31  	// CandidateState is a campaigner that is not the leader
    32  	CandidateState State = 1
    33  	// LeaderState is the leader
    34  	LeaderState State = 2
    35  )
    36  
    37  var (
    38  	stateNames = map[State]string{
    39  		UnknownState:   "unknown",
    40  		CandidateState: "candidate",
    41  		LeaderState:    "leader",
    42  	}
    43  )
    44  
    45  func (s State) String() string {
    46  	return stateNames[s]
    47  }
    48  
// implements inter.Election
//
// election is one member's view of a single leader election that is backed
// by a key in a NATS KV bucket. It is built by NewElection and driven by
// Start and Stop.
type election struct {
	// opts holds the member name, election key, bucket, callbacks and timings
	opts  *options
	// state is the member's current position in the election; guarded by mu
	state State

	// ctx and cancel bound the lifetime of a running campaign; both are
	// created by Start and cancel is invoked by stop and Stop
	ctx        context.Context
	cancel     context.CancelFunc
	// started is true from a successful Start until stop resets it
	started    bool
	// lastSeq is the KV revision returned by our last successful Create or
	// Update of the key, or math.MaxUint64 when we hold no such write
	lastSeq    uint64
	// tries counts failed attempts to create the key while a candidate; it
	// is reset to 0 on a win and passed to the Backoff, when one is set
	tries      int
	// notifyNext is set right after winning: the won callback is deferred to
	// the next successful refresh of the key, and IsLeader reports false
	// until then
	notifyNext bool

	// mu serialises campaign attempts (try) and guards state, started,
	// lastSeq and notifyNext
	mu sync.Mutex
}

// skipValidate disables the TTL and campaign interval sanity checks in
// configure. NOTE(review): nothing in this file sets it; presumably it is
// toggled by tests — confirm.
var skipValidate bool
    65  
    66  // NewElection creates a new leader election for a member name.
    67  //
    68  // Leader election is done using a KV Bucket where each key is an election, the key therefore
    69  // should be a unique identifier for the election.  Bucket is a loaded bucket using either NATS
    70  // libraries of the helper in the choria package.
    71  //
    72  // Buckets can be created using the NATS libraries or the "choria kv add ELECTION --ttl 30s --replicas 3"
    73  // command, here we set a 30s TTL on the bucket which would be used to influence campaign frequency and
    74  // the failover time after an outage of the leader. The smallest allowed TTL is 5 seconds though we suggest
    75  // picking the biggest number in the range of 30 to 60 seconds that works for your use case.
    76  //
    77  // A standard bucket CHORIA_LEADER_ELECTION gets made for Choria Streams and it's replicas and ttl is
    78  // configurable in the config file.
    79  func NewElection(name string, key string, bucket nats.KeyValue, opts ...Option) (*election, error) {
    80  	e := &election{
    81  		state:   UnknownState,
    82  		lastSeq: math.MaxUint64,
    83  		opts: &options{
    84  			name:   name,
    85  			key:    key,
    86  			bucket: bucket,
    87  		},
    88  	}
    89  
    90  	for _, opt := range opts {
    91  		opt(e.opts)
    92  	}
    93  
    94  	return e, nil
    95  }
    96  
    97  func (e *election) configure(ctx context.Context) error {
    98  	var status nats.KeyValueStatus
    99  
   100  	err := backoff.Default.For(ctx, func(try int) error {
   101  		var err error
   102  
   103  		status, err = e.opts.bucket.Status()
   104  		if err != nil {
   105  			e.debugf("Obtaining bucket stats failed on try %d: %v", try, err)
   106  		}
   107  
   108  		return err
   109  	})
   110  	if err != nil {
   111  		return err
   112  	}
   113  
   114  	e.opts.ttl = status.TTL()
   115  	if e.opts.cInterval == 0 {
   116  		e.opts.cInterval = time.Duration(float64(e.opts.ttl) * 0.75)
   117  	}
   118  
   119  	if !skipValidate {
   120  		if e.opts.ttl < time.Second {
   121  			return fmt.Errorf("bucket TTL should be 1 second or more")
   122  		}
   123  
   124  		if e.opts.ttl > time.Hour {
   125  			return fmt.Errorf("bucket TTL should be less than or equal to 1 hour")
   126  		}
   127  
   128  		if e.opts.cInterval.Seconds() < 1 {
   129  			return fmt.Errorf("campaign interval %v too small", e.opts.cInterval)
   130  		}
   131  
   132  		if e.opts.ttl.Seconds()-e.opts.cInterval.Seconds() < 1 {
   133  			return fmt.Errorf("campaign interval %v is too close to bucket ttl %v", e.opts.cInterval, e.opts.ttl)
   134  		}
   135  	}
   136  
   137  	e.debugf("Campaign interval: %v", e.opts.cInterval)
   138  
   139  	return nil
   140  }
   141  
   142  func (e *election) debugf(format string, a ...any) {
   143  	if e.opts.debug == nil {
   144  		return
   145  	}
   146  	e.opts.debug(format, a...)
   147  }
   148  
   149  func (e *election) campaignForLeadership() error {
   150  	campaignsCounter.WithLabelValues(e.opts.key, e.opts.name, stateNames[CandidateState]).Inc()
   151  
   152  	seq, err := e.opts.bucket.Create(e.opts.key, []byte(e.opts.name))
   153  	if err != nil {
   154  		e.tries++
   155  		return nil
   156  	}
   157  
   158  	e.lastSeq = seq
   159  	e.state = LeaderState
   160  	e.tries = 0
   161  	e.notifyNext = true // sets state that would notify about win on next campaign
   162  	leaderGauge.WithLabelValues(e.opts.key, e.opts.name).Set(1)
   163  
   164  	return nil
   165  }
   166  
   167  func (e *election) maintainLeadership() error {
   168  	campaignsCounter.WithLabelValues(e.opts.key, e.opts.name, stateNames[LeaderState]).Inc()
   169  
   170  	seq, err := e.opts.bucket.Update(e.opts.key, []byte(e.opts.name), e.lastSeq)
   171  	if err != nil {
   172  		e.debugf("key update failed, moving to candidate state: %v", err)
   173  		e.state = CandidateState
   174  		e.lastSeq = math.MaxUint64
   175  
   176  		leaderGauge.WithLabelValues(e.opts.key, e.opts.name).Set(0)
   177  
   178  		if e.opts.lostCb != nil {
   179  			e.opts.lostCb()
   180  		}
   181  
   182  		return err
   183  	}
   184  	e.lastSeq = seq
   185  
   186  	// we wait till the next campaign to notify that we are leader to give others a chance to stand down
   187  	if e.notifyNext {
   188  		e.notifyNext = false
   189  		if e.opts.wonCb != nil {
   190  			ctxSleep(e.ctx, 200*time.Millisecond)
   191  			e.opts.wonCb()
   192  		}
   193  	}
   194  
   195  	return nil
   196  }
   197  
   198  func (e *election) try() error {
   199  	e.mu.Lock()
   200  	defer e.mu.Unlock()
   201  
   202  	if e.opts.campaignCb != nil {
   203  		e.opts.campaignCb(e.state)
   204  	}
   205  
   206  	switch e.state {
   207  	case LeaderState:
   208  		return e.maintainLeadership()
   209  
   210  	case CandidateState:
   211  		return e.campaignForLeadership()
   212  
   213  	default:
   214  		return fmt.Errorf("campaigned while in unknown state")
   215  	}
   216  }
   217  
   218  func (e *election) campaign(wg *sync.WaitGroup) error {
   219  	defer wg.Done()
   220  
   221  	e.mu.Lock()
   222  	e.state = CandidateState
   223  	e.mu.Unlock()
   224  
   225  	// spread out startups a bit
   226  	splay := time.Duration(rand.Intn(5000)) * time.Millisecond
   227  	ctxSleep(e.ctx, splay)
   228  
   229  	var ticker *time.Ticker
   230  	if e.opts.bo != nil {
   231  		d := e.opts.bo.Duration(0)
   232  		campaignIntervalGauge.WithLabelValues(e.opts.key, e.opts.name).Set(d.Seconds())
   233  		ticker = time.NewTicker(d)
   234  	} else {
   235  		campaignIntervalGauge.WithLabelValues(e.opts.key, e.opts.name).Set(e.opts.cInterval.Seconds())
   236  		ticker = time.NewTicker(e.opts.cInterval)
   237  	}
   238  
   239  	tick := func() {
   240  		err := e.try()
   241  		if err != nil {
   242  			e.debugf("election attempt failed: %v", err)
   243  		}
   244  
   245  		if e.opts.bo != nil {
   246  			d := e.opts.bo.Duration(e.tries)
   247  			campaignIntervalGauge.WithLabelValues(e.opts.key, e.opts.name).Set(d.Seconds())
   248  			ticker.Reset(d)
   249  		}
   250  	}
   251  
   252  	// initial campaign
   253  	tick()
   254  
   255  	for {
   256  		select {
   257  		case <-ticker.C:
   258  			tick()
   259  
   260  		case <-e.ctx.Done():
   261  			ticker.Stop()
   262  			e.stop()
   263  
   264  			if e.opts.lostCb != nil && e.IsLeader() {
   265  				e.debugf("Calling leader lost during shutdown")
   266  				e.opts.lostCb()
   267  			}
   268  
   269  			return nil
   270  		}
   271  	}
   272  }
   273  
   274  func (e *election) stop() {
   275  	e.mu.Lock()
   276  	e.started = false
   277  	e.cancel()
   278  	e.state = CandidateState
   279  	e.lastSeq = math.MaxUint64
   280  	e.mu.Unlock()
   281  }
   282  
   283  func (e *election) Start(ctx context.Context) error {
   284  	e.mu.Lock()
   285  	if e.started {
   286  		e.mu.Unlock()
   287  		return fmt.Errorf("already running")
   288  	}
   289  
   290  	err := e.configure(ctx)
   291  	if err != nil {
   292  		return err
   293  	}
   294  
   295  	e.ctx, e.cancel = context.WithCancel(ctx)
   296  	e.started = true
   297  	e.mu.Unlock()
   298  
   299  	wg := &sync.WaitGroup{}
   300  	wg.Add(1)
   301  
   302  	err = e.campaign(wg)
   303  	if err != nil {
   304  		e.stop()
   305  		return err
   306  	}
   307  
   308  	wg.Wait()
   309  
   310  	return nil
   311  }
   312  
   313  func (e *election) Stop() {
   314  	e.mu.Lock()
   315  	defer e.mu.Unlock()
   316  
   317  	if !e.started {
   318  		return
   319  	}
   320  
   321  	if e.cancel != nil {
   322  		e.cancel()
   323  	}
   324  
   325  	e.stop()
   326  }
   327  
   328  func (e *election) IsLeader() bool {
   329  	e.mu.Lock()
   330  	defer e.mu.Unlock()
   331  
   332  	// it's only leader after the next successful campaign
   333  	return e.state == LeaderState && !e.notifyNext
   334  }
   335  
   336  func (e *election) State() State {
   337  	e.mu.Lock()
   338  	defer e.mu.Unlock()
   339  
   340  	return e.state
   341  }
   342  
   343  func ctxSleep(ctx context.Context, duration time.Duration) error {
   344  	if ctx.Err() != nil {
   345  		return ctx.Err()
   346  	}
   347  
   348  	sctx, cancel := context.WithTimeout(ctx, duration)
   349  	defer cancel()
   350  
   351  	<-sctx.Done()
   352  
   353  	return ctx.Err()
   354  }