github.com/xmidt-org/webpa-common@v1.11.9/service/consul/registrar.go (about)

     1  package consul
     2  
     3  import (
     4  	"fmt"
     5  	"sync"
     6  	"time"
     7  
     8  	"github.com/go-kit/kit/log"
     9  	"github.com/go-kit/kit/log/level"
    10  	"github.com/go-kit/kit/sd"
    11  	gokitconsul "github.com/go-kit/kit/sd/consul"
    12  	"github.com/hashicorp/consul/api"
    13  	"github.com/xmidt-org/webpa-common/logging"
    14  )
    15  
    16  // passFormat returns a closure that produces the output for a passing TTL, given the current system time
    17  func passFormat(serviceID string) func(time.Time) string {
    18  	return func(t time.Time) string {
    19  		return fmt.Sprintf("%s passed at %s", serviceID, t.UTC())
    20  	}
    21  }
    22  
    23  // failFormat returns a closure that produces the output for a critical TTL, given the current system time
    24  func failFormat(serviceID string) func(time.Time) string {
    25  	return func(t time.Time) string {
    26  		return fmt.Sprintf("%s failed at %s", serviceID, t.UTC())
    27  	}
    28  }
    29  
    30  func defaultTickerFactory(d time.Duration) (<-chan time.Time, func()) {
    31  	t := time.NewTicker(d)
    32  	return t.C, t.Stop
    33  }
    34  
// tickerFactory is the function updatePeriodically calls to obtain its tick channel and the
// matching stop function.  It is a package-level variable initialized to defaultTickerFactory,
// presumably so that tests can substitute a controllable ticker -- confirm against the package's tests.
var tickerFactory = defaultTickerFactory
    36  
// ttlUpdater represents any object which can update the TTL status on the remote consul cluster.
// Within this file, UpdateTTL is called with a status of "pass" on every tick and with a status
// of "fail" when an updater goroutine exits.
//
// NOTE(review): the previous comment stated that the consul api Client implements this interface.
// A method with this signature appears to be provided by the consul api Agent type rather than
// by Client itself -- confirm against the consul api package.
type ttlUpdater interface {
	// UpdateTTL reports the given status, with output as the accompanying text,
	// for the check identified by checkID.
	UpdateTTL(checkID, output, status string) error
}
    42  
// ttlCheck holds the relevant information for managing a TTL check.
// Instances are created by appendTTLCheck and driven by updatePeriodically.
type ttlCheck struct {
	checkID    string                 // identifier of the check, passed to UpdateTTL on every update
	interval   time.Duration          // period handed to tickerFactory; appendTTLCheck sets it to half the TTL
	logger     log.Logger             // receives the lifecycle and error messages for this check
	passFormat func(time.Time) string // renders the output sent with a "pass" update
	failFormat func(time.Time) string // renders the output sent with a "fail" update
}
    51  
    52  func (tc ttlCheck) updatePeriodically(updater ttlUpdater, shutdown <-chan struct{}) {
    53  	ticker, stop := tickerFactory(tc.interval)
    54  	defer stop()
    55  	defer func() {
    56  		if err := updater.UpdateTTL(tc.checkID, tc.failFormat(time.Now()), "fail"); err != nil {
    57  			tc.logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "error while updating TTL to critical", logging.ErrorKey(), err)
    58  		}
    59  	}()
    60  
    61  	tc.logger.Log(level.Key(), level.InfoValue(), logging.MessageKey(), "starting TTL updater")
    62  
    63  	// we log an error only on the first error, and then an info message if and when the update recovers.
    64  	// this avoids filling up the server's logs with what are almost certainly just duplicate errors over and over.
    65  	successiveErrorCount := 0
    66  
    67  	for {
    68  		select {
    69  		case t := <-ticker:
    70  			if err := updater.UpdateTTL(tc.checkID, tc.passFormat(t), "pass"); err != nil {
    71  				successiveErrorCount++
    72  				if successiveErrorCount == 1 {
    73  					tc.logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "error while updating TTL to passing", logging.ErrorKey(), err)
    74  				}
    75  			} else if successiveErrorCount > 0 {
    76  				tc.logger.Log(level.Key(), level.InfoValue(), logging.MessageKey(), "update TTL success", "previousErrorCount", successiveErrorCount)
    77  				successiveErrorCount = 0
    78  			}
    79  
    80  		case <-shutdown:
    81  			tc.logger.Log(level.Key(), level.ErrorValue(), logging.MessageKey(), "TTL updater shutdown")
    82  			return
    83  		}
    84  	}
    85  }
    86  
    87  // appendTTLCheck conditionally creates a ttlCheck for the given agent check if and only if the agent check is configured with a TTL.
    88  // If the agent check is nil or has no TTL, this function returns ttlChecks unmodified with no error.
    89  func appendTTLCheck(logger log.Logger, serviceID string, agentCheck *api.AgentServiceCheck, ttlChecks []ttlCheck) ([]ttlCheck, error) {
    90  	if agentCheck == nil || len(agentCheck.TTL) == 0 {
    91  		return ttlChecks, nil
    92  	}
    93  
    94  	ttl, err := time.ParseDuration(agentCheck.TTL)
    95  	if err != nil {
    96  		return nil, err
    97  	}
    98  
    99  	interval := ttl / 2
   100  	if interval < 1 {
   101  		return nil, fmt.Errorf("TTL %s is too small", agentCheck.TTL)
   102  	}
   103  
   104  	ttlChecks = append(
   105  		ttlChecks,
   106  		ttlCheck{
   107  			checkID:  agentCheck.CheckID,
   108  			interval: interval,
   109  			logger: log.With(
   110  				logger,
   111  				"serviceID", serviceID,
   112  				"checkID", agentCheck.CheckID,
   113  				"ttl", agentCheck.TTL,
   114  				"interval", interval.String(),
   115  			),
   116  			passFormat: passFormat(serviceID),
   117  			failFormat: failFormat(serviceID),
   118  		},
   119  	)
   120  
   121  	return ttlChecks, nil
   122  }
   123  
// ttlRegistrar is an sd.Registrar that binds one or more TTL updates to the Register/Deregister lifecycle.
// When Register is called, a goroutine is spawned for each TTL check that invokes UpdateTTL on an interval.
// When Deregister is called, any goroutines spawned are signaled to stop, and each one sets its check to
// fail (critical) as it exits.
type ttlRegistrar struct {
	logger    log.Logger   // stored by NewRegistrar; not read by Register or Deregister
	serviceID string       // ID of the service registration, stored by NewRegistrar
	registrar sd.Registrar // decorated registrar that performs the actual Register/Deregister
	updater   ttlUpdater   // handed to every updatePeriodically goroutine
	checks    []ttlCheck   // one entry per check that was configured with a TTL

	lifecycleLock sync.Mutex    // held for the duration of Register and Deregister; guards shutdown
	shutdown      chan struct{} // non-nil only between Register and Deregister; closed to stop the goroutines
}
   137  
   138  // NewRegistrar creates an sd.Registrar, binding any TTL checks to the Register/Deregister lifecycle as needed.
   139  func NewRegistrar(c gokitconsul.Client, u ttlUpdater, r *api.AgentServiceRegistration, logger log.Logger) (sd.Registrar, error) {
   140  	var (
   141  		ttlChecks []ttlCheck
   142  		err       error
   143  	)
   144  
   145  	ttlChecks, err = appendTTLCheck(logger, r.ID, r.Check, ttlChecks)
   146  	if err != nil {
   147  		return nil, err
   148  	}
   149  
   150  	for _, agentCheck := range r.Checks {
   151  		ttlChecks, err = appendTTLCheck(logger, r.ID, agentCheck, ttlChecks)
   152  		if err != nil {
   153  			return nil, err
   154  		}
   155  	}
   156  
   157  	var registrar sd.Registrar = gokitconsul.NewRegistrar(c, r, logger)
   158  
   159  	// decorate the given registrar if we have any TTL checks
   160  	if len(ttlChecks) > 0 {
   161  		registrar = &ttlRegistrar{
   162  			logger:    logger,
   163  			serviceID: r.ID,
   164  			registrar: registrar,
   165  			updater:   u,
   166  			checks:    ttlChecks,
   167  		}
   168  	}
   169  
   170  	return registrar, nil
   171  }
   172  
   173  func (tr *ttlRegistrar) Register() {
   174  	defer tr.lifecycleLock.Unlock()
   175  	tr.lifecycleLock.Lock()
   176  
   177  	if tr.shutdown != nil {
   178  		return
   179  	}
   180  
   181  	tr.registrar.Register()
   182  	tr.shutdown = make(chan struct{})
   183  	for _, tc := range tr.checks {
   184  		go tc.updatePeriodically(tr.updater, tr.shutdown)
   185  	}
   186  }
   187  
   188  func (tr *ttlRegistrar) Deregister() {
   189  	defer tr.lifecycleLock.Unlock()
   190  	tr.lifecycleLock.Lock()
   191  
   192  	if tr.shutdown == nil {
   193  		return
   194  	}
   195  
   196  	close(tr.shutdown)
   197  	tr.shutdown = nil
   198  	tr.registrar.Deregister()
   199  }