github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/registry/consul/consul.go (about)

     1  package consul
     2  
     3  import (
     4  	"crypto/tls"
     5  	"errors"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"runtime"
    10  	"strconv"
    11  	"sync"
    12  	"time"
    13  
    14  	consul "github.com/hashicorp/consul/api"
    15  	hash "github.com/mitchellh/hashstructure"
    16  	mnet "github.com/volts-dev/volts/internal/net"
    17  	"github.com/volts-dev/volts/registry"
    18  )
    19  
    20  type consulRegistry struct {
    21  	Address []string
    22  	opts    *registry.Config
    23  	client  *consul.Client
    24  	config  *consul.Config
    25  
    26  	// connect enabled
    27  	connect bool
    28  
    29  	queryOptions *consul.QueryOptions
    30  
    31  	sync.Mutex
    32  	register map[string]uint64
    33  	// lastChecked tracks when a node was last checked as existing in Consul
    34  	lastChecked map[string]time.Time
    35  }
    36  
// init registers New with the registry package under the name "consul".
func init() {
	registry.Register("consul", New)
}
    40  
    41  func New(opts ...registry.Option) registry.IRegistry {
    42  	var defaultOpts []registry.Option
    43  	defaultOpts = append(defaultOpts,
    44  		registry.WithName("consul"),
    45  		registry.Timeout(time.Millisecond*100),
    46  	)
    47  
    48  	cr := &consulRegistry{
    49  		opts:        registry.NewConfig(append(defaultOpts, opts...)...),
    50  		register:    make(map[string]uint64),
    51  		lastChecked: make(map[string]time.Time),
    52  		queryOptions: &consul.QueryOptions{
    53  			AllowStale: true,
    54  		},
    55  	}
    56  	configure(cr)
    57  	return cr
    58  }
    59  
    60  func getDeregisterTTL(t time.Duration) time.Duration {
    61  	// splay slightly for the watcher?
    62  	splay := time.Second * 5
    63  	deregTTL := t + splay
    64  
    65  	// consul has a minimum timeout on deregistration of 1 minute.
    66  	if t < time.Minute {
    67  		deregTTL = time.Minute + splay
    68  	}
    69  
    70  	return deregTTL
    71  }
    72  
    73  func newTransport(config *tls.Config) *http.Transport {
    74  	if config == nil {
    75  		config = &tls.Config{
    76  			InsecureSkipVerify: true,
    77  		}
    78  	}
    79  
    80  	t := &http.Transport{
    81  		Proxy: http.ProxyFromEnvironment,
    82  		Dial: (&net.Dialer{
    83  			Timeout:   30 * time.Second,
    84  			KeepAlive: 30 * time.Second,
    85  		}).Dial,
    86  		TLSHandshakeTimeout: 10 * time.Second,
    87  		TLSClientConfig:     config,
    88  	}
    89  	runtime.SetFinalizer(&t, func(tr **http.Transport) {
    90  		(*tr).CloseIdleConnections()
    91  	})
    92  	return t
    93  }
    94  
    95  func configure(c *consulRegistry) {
    96  	// use default non pooled config
    97  	config := consul.DefaultNonPooledConfig()
    98  	c.opts.Name = c.String()
    99  	if c.opts.Context != nil {
   100  		// Use the consul config passed in the options, if available
   101  		if co, ok := c.opts.Context.Value("consul_config").(*consul.Config); ok {
   102  			config = co
   103  		}
   104  		if cn, ok := c.opts.Context.Value("consul_connect").(bool); ok {
   105  			c.connect = cn
   106  		}
   107  
   108  		// Use the consul query options passed in the options, if available
   109  		if qo, ok := c.opts.Context.Value("consul_query_options").(*consul.QueryOptions); ok && qo != nil {
   110  			c.queryOptions = qo
   111  		}
   112  		if as, ok := c.opts.Context.Value("consul_allow_stale").(bool); ok {
   113  			c.queryOptions.AllowStale = as
   114  		}
   115  	}
   116  
   117  	// check if there are any addrs
   118  	var addrs []string
   119  
   120  	// iterate the options addresses
   121  	for _, address := range c.opts.Addrs {
   122  		// check we have a port
   123  		addr, port, err := net.SplitHostPort(address)
   124  		if ae, ok := err.(*net.AddrError); ok && ae.Err == "missing port in address" {
   125  			port = "8500"
   126  			addr = address
   127  			addrs = append(addrs, net.JoinHostPort(addr, port))
   128  		} else if err == nil {
   129  			addrs = append(addrs, net.JoinHostPort(addr, port))
   130  		}
   131  	}
   132  
   133  	// set the addrs
   134  	if len(addrs) > 0 {
   135  		c.Address = addrs
   136  		config.Address = c.Address[0]
   137  	}
   138  
   139  	if config.HttpClient == nil {
   140  		config.HttpClient = new(http.Client)
   141  	}
   142  
   143  	// requires secure connection?
   144  	if c.opts.Secure || c.opts.TlsConfig != nil {
   145  		config.Scheme = "https"
   146  		// We're going to support InsecureSkipVerify
   147  		config.HttpClient.Transport = newTransport(c.opts.TlsConfig)
   148  	}
   149  
   150  	// set timeout
   151  	if c.opts.Timeout > 0 {
   152  		config.HttpClient.Timeout = c.opts.Timeout
   153  	}
   154  
   155  	// set the config
   156  	c.config = config
   157  
   158  	// remove client
   159  	c.client = nil
   160  
   161  	// setup the client
   162  	c.Client()
   163  }
   164  
// Init applies opts to the registry configuration and then rebuilds the
// Consul client configuration via configure. It always returns nil.
func (c *consulRegistry) Init(opts ...registry.Option) error {
	c.opts.Init(opts...)
	configure(c)
	return nil
}
   170  
   171  func (c *consulRegistry) Deregister(s *registry.Service, opts ...registry.Option) error {
   172  	if len(s.Nodes) == 0 {
   173  		return errors.New("Require at least one node")
   174  	}
   175  
   176  	// delete our hash and time check of the service
   177  	c.Lock()
   178  	delete(c.register, s.Name)
   179  	delete(c.lastChecked, s.Name)
   180  	c.Unlock()
   181  
   182  	node := s.Nodes[0]
   183  	return c.Client().Agent().ServiceDeregister(node.Id)
   184  }
   185  
   186  func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Option) error {
   187  	if len(s.Nodes) == 0 {
   188  		return errors.New("Require at least one node")
   189  	}
   190  
   191  	var regTCPCheck bool
   192  	var regInterval time.Duration
   193  
   194  	var options registry.Config
   195  	for _, o := range opts {
   196  		o(&options)
   197  	}
   198  
   199  	if c.opts.Context != nil {
   200  		if tcpCheckInterval, ok := c.opts.Context.Value("consul_tcp_check").(time.Duration); ok {
   201  			regTCPCheck = true
   202  			regInterval = tcpCheckInterval
   203  		}
   204  	}
   205  
   206  	// create hash of service; uint64
   207  	h, err := hash.Hash(s, nil)
   208  	if err != nil {
   209  		return err
   210  	}
   211  
   212  	// use first node
   213  	node := s.Nodes[0]
   214  
   215  	// get existing hash and last checked time
   216  	c.Lock()
   217  	v, ok := c.register[s.Name]
   218  	lastChecked := c.lastChecked[s.Name]
   219  	c.Unlock()
   220  
   221  	// if it's already registered and matches then just pass the check
   222  	if ok && v == h {
   223  		if options.TTL == time.Duration(0) {
   224  			// ensure that our service hasn't been deregistered by Consul
   225  			if time.Since(lastChecked) <= getDeregisterTTL(regInterval) {
   226  				return nil
   227  			}
   228  			services, _, err := c.Client().Health().Checks(s.Name, c.queryOptions)
   229  			if err == nil {
   230  				for _, v := range services {
   231  					if v.ServiceID == node.Id {
   232  						return nil
   233  					}
   234  				}
   235  			}
   236  		} else {
   237  			// if the err is nil we're all good, bail out
   238  			// if not, we don't know what the state is, so full re-register
   239  			if err := c.Client().Agent().PassTTL("service:"+node.Id, ""); err == nil {
   240  				return nil
   241  			}
   242  		}
   243  	}
   244  
   245  	// encode the tags
   246  	tags := encodeMetadata(node.Metadata)
   247  	tags = append(tags, encodeEndpoints(s.Endpoints)...)
   248  	tags = append(tags, encodeVersion(s.Version)...)
   249  
   250  	var check *consul.AgentServiceCheck
   251  
   252  	if regTCPCheck {
   253  		deregTTL := getDeregisterTTL(regInterval)
   254  
   255  		check = &consul.AgentServiceCheck{
   256  			TCP:                            node.Address,
   257  			Interval:                       fmt.Sprintf("%v", regInterval),
   258  			DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL),
   259  		}
   260  
   261  		// if the TTL is greater than 0 create an associated check
   262  	} else if options.TTL > time.Duration(0) {
   263  		deregTTL := getDeregisterTTL(options.TTL)
   264  
   265  		check = &consul.AgentServiceCheck{
   266  			TTL:                            fmt.Sprintf("%v", options.TTL),
   267  			DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL),
   268  		}
   269  	}
   270  
   271  	host, pt, _ := net.SplitHostPort(node.Address)
   272  	if host == "" {
   273  		host = node.Address
   274  	}
   275  	port, _ := strconv.Atoi(pt)
   276  
   277  	// register the service
   278  	asr := &consul.AgentServiceRegistration{
   279  		ID:      node.Id,
   280  		Name:    s.Name,
   281  		Tags:    tags,
   282  		Port:    port,
   283  		Address: host,
   284  		Check:   check,
   285  	}
   286  
   287  	// Specify consul connect
   288  	if c.connect {
   289  		asr.Connect = &consul.AgentServiceConnect{
   290  			Native: true,
   291  		}
   292  	}
   293  
   294  	if err := c.Client().Agent().ServiceRegister(asr); err != nil {
   295  		return err
   296  	}
   297  
   298  	// save our hash and time check of the service
   299  	c.Lock()
   300  	c.register[s.Name] = h
   301  	c.lastChecked[s.Name] = time.Now()
   302  	c.Unlock()
   303  
   304  	// if the TTL is 0 we don't mess with the checks
   305  	if options.TTL == time.Duration(0) {
   306  		return nil
   307  	}
   308  
   309  	c.opts.LocalServices = append(c.opts.LocalServices, s)
   310  	// pass the healthcheck
   311  	return c.Client().Agent().PassTTL("service:"+node.Id, "")
   312  }
   313  
   314  func (m *consulRegistry) LocalServices() []*registry.Service {
   315  	return m.opts.LocalServices
   316  }
   317  
   318  func (c *consulRegistry) GetService(name string) ([]*registry.Service, error) {
   319  	var rsp []*consul.ServiceEntry
   320  	var err error
   321  
   322  	// if we're connect enabled only get connect services
   323  	if c.connect {
   324  		rsp, _, err = c.Client().Health().Connect(name, "", false, c.queryOptions)
   325  	} else {
   326  		rsp, _, err = c.Client().Health().Service(name, "", false, c.queryOptions)
   327  	}
   328  	if err != nil {
   329  		return nil, err
   330  	}
   331  
   332  	serviceMap := map[string]*registry.Service{}
   333  
   334  	for _, s := range rsp {
   335  		if s.Service.Service != name {
   336  			continue
   337  		}
   338  
   339  		// version is now a tag
   340  		version, _ := decodeVersion(s.Service.Tags)
   341  		// service ID is now the node id
   342  		id := s.Service.ID
   343  		// key is always the version
   344  		key := version
   345  
   346  		// address is service address
   347  		address := s.Service.Address
   348  
   349  		// use node address
   350  		if len(address) == 0 {
   351  			address = s.Node.Address
   352  		}
   353  
   354  		svc, ok := serviceMap[key]
   355  		if !ok {
   356  			svc = &registry.Service{
   357  				Endpoints: decodeEndpoints(s.Service.Tags),
   358  				Name:      s.Service.Service,
   359  				Version:   version,
   360  			}
   361  			serviceMap[key] = svc
   362  		}
   363  
   364  		var del bool
   365  
   366  		for _, check := range s.Checks {
   367  			// delete the node if the status is critical
   368  			if check.Status == "critical" {
   369  				del = true
   370  				break
   371  			}
   372  		}
   373  
   374  		// if delete then skip the node
   375  		if del {
   376  			continue
   377  		}
   378  
   379  		svc.Nodes = append(svc.Nodes, &registry.Node{
   380  			Id:       id,
   381  			Address:  mnet.HostPort(address, s.Service.Port),
   382  			Metadata: decodeMetadata(s.Service.Tags),
   383  		})
   384  	}
   385  
   386  	var services []*registry.Service
   387  	for _, service := range serviceMap {
   388  		services = append(services, service)
   389  	}
   390  	return services, nil
   391  }
   392  
   393  func (c *consulRegistry) ListServices() ([]*registry.Service, error) {
   394  	rsp, _, err := c.Client().Catalog().Services(c.queryOptions)
   395  	if err != nil {
   396  		return nil, err
   397  	}
   398  
   399  	var services []*registry.Service
   400  
   401  	for service := range rsp {
   402  		services = append(services, &registry.Service{Name: service})
   403  	}
   404  
   405  	return services, nil
   406  }
   407  
// Watcher returns a registry.Watcher for this registry; construction is
// delegated to newConsulWatcher with the given watch options.
func (c *consulRegistry) Watcher(opts ...registry.WatchOptions) (registry.Watcher, error) {
	return newConsulWatcher(c, opts...)
}
   411  
// String returns the registry name held in the options (New defaults it to
// "consul").
func (c *consulRegistry) String() string {
	return c.opts.Name
}
   415  
// Config returns the registry's configuration. The pointer is shared, not
// copied, so changes made through it affect this registry.
func (c *consulRegistry) Config() *registry.Config {
	return c.opts
}
   419  
   420  func (c *consulRegistry) Client() *consul.Client {
   421  	if c.client != nil {
   422  		return c.client
   423  	}
   424  
   425  	for _, addr := range c.Address {
   426  		// set the address
   427  		c.config.Address = addr
   428  
   429  		// create a new client
   430  		tmpClient, _ := consul.NewClient(c.config)
   431  
   432  		// test the client
   433  		_, err := tmpClient.Agent().Host()
   434  		if err != nil {
   435  			continue
   436  		}
   437  
   438  		// set the client
   439  		c.client = tmpClient
   440  		return c.client
   441  	}
   442  
   443  	// set the default
   444  	c.client, _ = consul.NewClient(c.config)
   445  
   446  	// return the client
   447  	return c.client
   448  }