github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/registry/consul/consul.go (about) 1 package consul 2 3 import ( 4 "crypto/tls" 5 "errors" 6 "fmt" 7 "net" 8 "net/http" 9 "runtime" 10 "strconv" 11 "sync" 12 "time" 13 14 consul "github.com/hashicorp/consul/api" 15 hash "github.com/mitchellh/hashstructure" 16 mnet "github.com/volts-dev/volts/internal/net" 17 "github.com/volts-dev/volts/registry" 18 ) 19 20 type consulRegistry struct { 21 Address []string 22 opts *registry.Config 23 client *consul.Client 24 config *consul.Config 25 26 // connect enabled 27 connect bool 28 29 queryOptions *consul.QueryOptions 30 31 sync.Mutex 32 register map[string]uint64 33 // lastChecked tracks when a node was last checked as existing in Consul 34 lastChecked map[string]time.Time 35 } 36 37 func init() { 38 registry.Register("consul", New) 39 } 40 41 func New(opts ...registry.Option) registry.IRegistry { 42 var defaultOpts []registry.Option 43 defaultOpts = append(defaultOpts, 44 registry.WithName("consul"), 45 registry.Timeout(time.Millisecond*100), 46 ) 47 48 cr := &consulRegistry{ 49 opts: registry.NewConfig(append(defaultOpts, opts...)...), 50 register: make(map[string]uint64), 51 lastChecked: make(map[string]time.Time), 52 queryOptions: &consul.QueryOptions{ 53 AllowStale: true, 54 }, 55 } 56 configure(cr) 57 return cr 58 } 59 60 func getDeregisterTTL(t time.Duration) time.Duration { 61 // splay slightly for the watcher? 62 splay := time.Second * 5 63 deregTTL := t + splay 64 65 // consul has a minimum timeout on deregistration of 1 minute. 66 if t < time.Minute { 67 deregTTL = time.Minute + splay 68 } 69 70 return deregTTL 71 } 72 73 func newTransport(config *tls.Config) *http.Transport { 74 if config == nil { 75 config = &tls.Config{ 76 InsecureSkipVerify: true, 77 } 78 } 79 80 t := &http.Transport{ 81 Proxy: http.ProxyFromEnvironment, 82 Dial: (&net.Dialer{ 83 Timeout: 30 * time.Second, 84 KeepAlive: 30 * time.Second, 85 }).Dial, 86 TLSHandshakeTimeout: 10 * time.Second, 87 TLSClientConfig: config, 88 } 89 runtime.SetFinalizer(&t, func(tr **http.Transport) { 90 (*tr).CloseIdleConnections() 91 }) 92 return t 93 } 94 95 func configure(c *consulRegistry) { 96 // use default non pooled config 97 config := consul.DefaultNonPooledConfig() 98 c.opts.Name = c.String() 99 if c.opts.Context != nil { 100 // Use the consul config passed in the options, if available 101 if co, ok := c.opts.Context.Value("consul_config").(*consul.Config); ok { 102 config = co 103 } 104 if cn, ok := c.opts.Context.Value("consul_connect").(bool); ok { 105 c.connect = cn 106 } 107 108 // Use the consul query options passed in the options, if available 109 if qo, ok := c.opts.Context.Value("consul_query_options").(*consul.QueryOptions); ok && qo != nil { 110 c.queryOptions = qo 111 } 112 if as, ok := c.opts.Context.Value("consul_allow_stale").(bool); ok { 113 c.queryOptions.AllowStale = as 114 } 115 } 116 117 // check if there are any addrs 118 var addrs []string 119 120 // iterate the options addresses 121 for _, address := range c.opts.Addrs { 122 // check we have a port 123 addr, port, err := net.SplitHostPort(address) 124 if ae, ok := err.(*net.AddrError); ok && ae.Err == "missing port in address" { 125 port = "8500" 126 addr = address 127 addrs = append(addrs, net.JoinHostPort(addr, port)) 128 } else if err == nil { 129 addrs = append(addrs, net.JoinHostPort(addr, port)) 130 } 131 } 132 133 // set the addrs 134 if len(addrs) > 0 { 135 c.Address = addrs 136 config.Address = c.Address[0] 137 } 138 139 if config.HttpClient == nil { 140 config.HttpClient = new(http.Client) 141 } 142 143 // requires secure connection? 144 if c.opts.Secure || c.opts.TlsConfig != nil { 145 config.Scheme = "https" 146 // We're going to support InsecureSkipVerify 147 config.HttpClient.Transport = newTransport(c.opts.TlsConfig) 148 } 149 150 // set timeout 151 if c.opts.Timeout > 0 { 152 config.HttpClient.Timeout = c.opts.Timeout 153 } 154 155 // set the config 156 c.config = config 157 158 // remove client 159 c.client = nil 160 161 // setup the client 162 c.Client() 163 } 164 165 func (c *consulRegistry) Init(opts ...registry.Option) error { 166 c.opts.Init(opts...) 167 configure(c) 168 return nil 169 } 170 171 func (c *consulRegistry) Deregister(s *registry.Service, opts ...registry.Option) error { 172 if len(s.Nodes) == 0 { 173 return errors.New("Require at least one node") 174 } 175 176 // delete our hash and time check of the service 177 c.Lock() 178 delete(c.register, s.Name) 179 delete(c.lastChecked, s.Name) 180 c.Unlock() 181 182 node := s.Nodes[0] 183 return c.Client().Agent().ServiceDeregister(node.Id) 184 } 185 186 func (c *consulRegistry) Register(s *registry.Service, opts ...registry.Option) error { 187 if len(s.Nodes) == 0 { 188 return errors.New("Require at least one node") 189 } 190 191 var regTCPCheck bool 192 var regInterval time.Duration 193 194 var options registry.Config 195 for _, o := range opts { 196 o(&options) 197 } 198 199 if c.opts.Context != nil { 200 if tcpCheckInterval, ok := c.opts.Context.Value("consul_tcp_check").(time.Duration); ok { 201 regTCPCheck = true 202 regInterval = tcpCheckInterval 203 } 204 } 205 206 // create hash of service; uint64 207 h, err := hash.Hash(s, nil) 208 if err != nil { 209 return err 210 } 211 212 // use first node 213 node := s.Nodes[0] 214 215 // get existing hash and last checked time 216 c.Lock() 217 v, ok := c.register[s.Name] 218 lastChecked := c.lastChecked[s.Name] 219 c.Unlock() 220 221 // if it's already registered and matches then just pass the check 222 if ok && v == h { 223 if options.TTL == time.Duration(0) { 224 // ensure that our service hasn't been deregistered by Consul 225 if time.Since(lastChecked) <= getDeregisterTTL(regInterval) { 226 return nil 227 } 228 services, _, err := c.Client().Health().Checks(s.Name, c.queryOptions) 229 if err == nil { 230 for _, v := range services { 231 if v.ServiceID == node.Id { 232 return nil 233 } 234 } 235 } 236 } else { 237 // if the err is nil we're all good, bail out 238 // if not, we don't know what the state is, so full re-register 239 if err := c.Client().Agent().PassTTL("service:"+node.Id, ""); err == nil { 240 return nil 241 } 242 } 243 } 244 245 // encode the tags 246 tags := encodeMetadata(node.Metadata) 247 tags = append(tags, encodeEndpoints(s.Endpoints)...) 248 tags = append(tags, encodeVersion(s.Version)...) 249 250 var check *consul.AgentServiceCheck 251 252 if regTCPCheck { 253 deregTTL := getDeregisterTTL(regInterval) 254 255 check = &consul.AgentServiceCheck{ 256 TCP: node.Address, 257 Interval: fmt.Sprintf("%v", regInterval), 258 DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL), 259 } 260 261 // if the TTL is greater than 0 create an associated check 262 } else if options.TTL > time.Duration(0) { 263 deregTTL := getDeregisterTTL(options.TTL) 264 265 check = &consul.AgentServiceCheck{ 266 TTL: fmt.Sprintf("%v", options.TTL), 267 DeregisterCriticalServiceAfter: fmt.Sprintf("%v", deregTTL), 268 } 269 } 270 271 host, pt, _ := net.SplitHostPort(node.Address) 272 if host == "" { 273 host = node.Address 274 } 275 port, _ := strconv.Atoi(pt) 276 277 // register the service 278 asr := &consul.AgentServiceRegistration{ 279 ID: node.Id, 280 Name: s.Name, 281 Tags: tags, 282 Port: port, 283 Address: host, 284 Check: check, 285 } 286 287 // Specify consul connect 288 if c.connect { 289 asr.Connect = &consul.AgentServiceConnect{ 290 Native: true, 291 } 292 } 293 294 if err := c.Client().Agent().ServiceRegister(asr); err != nil { 295 return err 296 } 297 298 // save our hash and time check of the service 299 c.Lock() 300 c.register[s.Name] = h 301 c.lastChecked[s.Name] = time.Now() 302 c.Unlock() 303 304 // if the TTL is 0 we don't mess with the checks 305 if options.TTL == time.Duration(0) { 306 return nil 307 } 308 309 c.opts.LocalServices = append(c.opts.LocalServices, s) 310 // pass the healthcheck 311 return c.Client().Agent().PassTTL("service:"+node.Id, "") 312 } 313 314 func (m *consulRegistry) LocalServices() []*registry.Service { 315 return m.opts.LocalServices 316 } 317 318 func (c *consulRegistry) GetService(name string) ([]*registry.Service, error) { 319 var rsp []*consul.ServiceEntry 320 var err error 321 322 // if we're connect enabled only get connect services 323 if c.connect { 324 rsp, _, err = c.Client().Health().Connect(name, "", false, c.queryOptions) 325 } else { 326 rsp, _, err = c.Client().Health().Service(name, "", false, c.queryOptions) 327 } 328 if err != nil { 329 return nil, err 330 } 331 332 serviceMap := map[string]*registry.Service{} 333 334 for _, s := range rsp { 335 if s.Service.Service != name { 336 continue 337 } 338 339 // version is now a tag 340 version, _ := decodeVersion(s.Service.Tags) 341 // service ID is now the node id 342 id := s.Service.ID 343 // key is always the version 344 key := version 345 346 // address is service address 347 address := s.Service.Address 348 349 // use node address 350 if len(address) == 0 { 351 address = s.Node.Address 352 } 353 354 svc, ok := serviceMap[key] 355 if !ok { 356 svc = ®istry.Service{ 357 Endpoints: decodeEndpoints(s.Service.Tags), 358 Name: s.Service.Service, 359 Version: version, 360 } 361 serviceMap[key] = svc 362 } 363 364 var del bool 365 366 for _, check := range s.Checks { 367 // delete the node if the status is critical 368 if check.Status == "critical" { 369 del = true 370 break 371 } 372 } 373 374 // if delete then skip the node 375 if del { 376 continue 377 } 378 379 svc.Nodes = append(svc.Nodes, ®istry.Node{ 380 Id: id, 381 Address: mnet.HostPort(address, s.Service.Port), 382 Metadata: decodeMetadata(s.Service.Tags), 383 }) 384 } 385 386 var services []*registry.Service 387 for _, service := range serviceMap { 388 services = append(services, service) 389 } 390 return services, nil 391 } 392 393 func (c *consulRegistry) ListServices() ([]*registry.Service, error) { 394 rsp, _, err := c.Client().Catalog().Services(c.queryOptions) 395 if err != nil { 396 return nil, err 397 } 398 399 var services []*registry.Service 400 401 for service := range rsp { 402 services = append(services, ®istry.Service{Name: service}) 403 } 404 405 return services, nil 406 } 407 408 func (c *consulRegistry) Watcher(opts ...registry.WatchOptions) (registry.Watcher, error) { 409 return newConsulWatcher(c, opts...) 410 } 411 412 func (c *consulRegistry) String() string { 413 return c.opts.Name 414 } 415 416 func (c *consulRegistry) Config() *registry.Config { 417 return c.opts 418 } 419 420 func (c *consulRegistry) Client() *consul.Client { 421 if c.client != nil { 422 return c.client 423 } 424 425 for _, addr := range c.Address { 426 // set the address 427 c.config.Address = addr 428 429 // create a new client 430 tmpClient, _ := consul.NewClient(c.config) 431 432 // test the client 433 _, err := tmpClient.Agent().Host() 434 if err != nil { 435 continue 436 } 437 438 // set the client 439 c.client = tmpClient 440 return c.client 441 } 442 443 // set the default 444 c.client, _ = consul.NewClient(c.config) 445 446 // return the client 447 return c.client 448 }