github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/registry/mdns/mdns.go (about) 1 package mdns 2 3 import ( 4 "bytes" 5 "compress/zlib" 6 "context" 7 "encoding/hex" 8 "encoding/json" 9 "fmt" 10 "io/ioutil" 11 "net" 12 "strconv" 13 "strings" 14 "sync" 15 "time" 16 17 "github.com/google/uuid" 18 "github.com/volts-dev/volts/internal/mdns" 19 "github.com/volts-dev/volts/registry" 20 ) 21 22 var log = registry.Logger() 23 24 type ( 25 mdnsRegistry struct { 26 sync.Mutex 27 mtx sync.RWMutex 28 domain string // the mdns domain 29 config *registry.Config 30 services map[string][]*mdnsEntry 31 watchers map[string]*mdnsWatcher // watchers 32 listener chan *mdns.ServiceEntry // listener 33 } 34 35 mdnsEntry struct { 36 id string 37 node *mdns.Server 38 } 39 40 mdnsWatcher struct { 41 id string 42 wo *registry.WatchConfig 43 ch chan *mdns.ServiceEntry 44 exit chan struct{} 45 // the mdns domain 46 domain string 47 // the registry 48 registry *mdnsRegistry 49 } 50 51 mdnsTxt struct { 52 Service string 53 Version string 54 Endpoints []*registry.Endpoint 55 Metadata map[string]string 56 } 57 ) 58 59 func init() { 60 registry.Register("mdns", New) 61 } 62 63 // NewMdnsRegistry 64 func New(opts ...registry.Option) registry.IRegistry { 65 var defaultOpts []registry.Option 66 defaultOpts = append(defaultOpts, 67 registry.WithName("mdns"), 68 registry.Timeout(time.Millisecond*100), 69 ) 70 cfg := registry.NewConfig(append(defaultOpts, opts...)...) 71 // set the domain 72 domain := mdnsDomain 73 d, ok := cfg.Context.Value("mdns.domain").(string) 74 if ok { 75 domain = d 76 } 77 78 reg := &mdnsRegistry{ 79 domain: domain, 80 config: cfg, 81 services: make(map[string][]*mdnsEntry), 82 watchers: make(map[string]*mdnsWatcher), 83 listener: make(chan *mdns.ServiceEntry), 84 } 85 reg.config.Name = reg.String() 86 return reg 87 } 88 89 func encode(txt *mdnsTxt) ([]string, error) { 90 b, err := json.Marshal(txt) 91 if err != nil { 92 return nil, err 93 } 94 95 var buf bytes.Buffer 96 defer buf.Reset() 97 98 w := zlib.NewWriter(&buf) 99 if _, err := w.Write(b); err != nil { 100 return nil, err 101 } 102 w.Close() 103 104 encoded := hex.EncodeToString(buf.Bytes()) 105 106 // individual txt limit 107 if len(encoded) <= 255 { 108 return []string{encoded}, nil 109 } 110 111 // split encoded string 112 var record []string 113 114 for len(encoded) > 255 { 115 record = append(record, encoded[:255]) 116 encoded = encoded[255:] 117 } 118 119 record = append(record, encoded) 120 121 return record, nil 122 } 123 124 func decode(record []string) (*mdnsTxt, error) { 125 encoded := strings.Join(record, "") 126 127 hr, err := hex.DecodeString(encoded) 128 if err != nil { 129 return nil, err 130 } 131 132 br := bytes.NewReader(hr) 133 zr, err := zlib.NewReader(br) 134 if err != nil { 135 return nil, err 136 } 137 138 rbuf, err := ioutil.ReadAll(zr) 139 if err != nil { 140 return nil, err 141 } 142 143 var txt *mdnsTxt 144 145 if err := json.Unmarshal(rbuf, &txt); err != nil { 146 return nil, err 147 } 148 149 return txt, nil 150 } 151 152 func (self *mdnsRegistry) Init(opts ...registry.Option) error { 153 for _, o := range opts { 154 o(self.config) 155 } 156 return nil 157 } 158 159 func (self *mdnsRegistry) Config() *registry.Config { 160 return self.config 161 } 162 163 func (m *mdnsRegistry) Register(service *registry.Service, opts ...registry.Option) error { 164 m.Lock() 165 defer m.Unlock() 166 167 entries, ok := m.services[service.Name] 168 // first entry, create wildcard used for list queries 169 if !ok { 170 s, err := mdns.NewMDNSService( 171 service.Name, 172 "_services", 173 m.domain+".", 174 "", 175 9999, 176 []net.IP{net.ParseIP("0.0.0.0")}, 177 nil, 178 ) 179 if err != nil { 180 return err 181 } 182 183 srv, err := mdns.NewServer(&mdns.Config{Zone: &mdns.DNSSDService{MDNSService: s}}) 184 if err != nil { 185 return err 186 } 187 188 // append the wildcard entry 189 entries = append(entries, &mdnsEntry{id: "*", node: srv}) 190 } 191 192 var gerr error 193 194 for _, node := range service.Nodes { 195 var seen bool 196 var e *mdnsEntry 197 198 for _, entry := range entries { 199 if node.Id == entry.id { 200 seen = true 201 e = entry 202 break 203 } 204 } 205 206 // already registered, continue 207 if seen { 208 continue 209 // doesn't exist 210 } else { 211 e = &mdnsEntry{} 212 } 213 214 txt, err := encode(&mdnsTxt{ 215 Service: service.Name, 216 Version: service.Version, 217 Endpoints: service.Endpoints, 218 Metadata: node.Metadata, 219 }) 220 221 if err != nil { 222 gerr = err 223 continue 224 } 225 226 host, pt, err := net.SplitHostPort(node.Address) 227 if err != nil { 228 gerr = err 229 continue 230 } 231 port, _ := strconv.Atoi(pt) 232 233 //if logger.GetLevel()=={ 234 log.Dbgf("[mdns] registry create new service with ip: %s for: %s", net.ParseIP(host).String(), host) 235 //} 236 // we got here, new node 237 s, err := mdns.NewMDNSService( 238 node.Id, 239 service.Name, 240 m.domain+".", 241 "", 242 port, 243 []net.IP{net.ParseIP(host)}, 244 txt, 245 ) 246 if err != nil { 247 gerr = err 248 continue 249 } 250 251 srv, err := mdns.NewServer(&mdns.Config{Zone: s, LocalhostChecking: true}) 252 if err != nil { 253 gerr = err 254 continue 255 } 256 257 e.id = node.Id 258 e.node = srv 259 entries = append(entries, e) 260 } 261 262 if gerr == nil { 263 // save 264 m.services[service.Name] = entries 265 m.config.LocalServices = append(m.config.LocalServices, service) 266 } 267 268 return gerr 269 } 270 271 func (m *mdnsRegistry) Deregister(service *registry.Service, opts ...registry.Option) error { 272 m.Lock() 273 defer m.Unlock() 274 275 var newEntries []*mdnsEntry 276 277 // loop existing entries, check if any match, shutdown those that do 278 for _, entry := range m.services[service.Name] { 279 var remove bool 280 281 for _, node := range service.Nodes { 282 if node.Id == entry.id { 283 entry.node.Shutdown() 284 remove = true 285 break 286 } 287 } 288 289 // keep it? 290 if !remove { 291 newEntries = append(newEntries, entry) 292 } 293 } 294 295 // last entry is the wildcard for list queries. Remove it. 296 if len(newEntries) == 1 && newEntries[0].id == "*" { 297 newEntries[0].node.Shutdown() 298 delete(m.services, service.Name) 299 } else { 300 m.services[service.Name] = newEntries 301 } 302 303 return nil 304 } 305 306 func (m *mdnsRegistry) LocalServices() []*registry.Service { 307 return m.config.LocalServices 308 } 309 310 func (m *mdnsRegistry) GetService(service string) ([]*registry.Service, error) { 311 serviceMap := make(map[string]*registry.Service) 312 entries := make(chan *mdns.ServiceEntry, 10) 313 done := make(chan bool) 314 315 p := mdns.DefaultParams(service) 316 // set context with timeout 317 var cancel context.CancelFunc 318 p.Context, cancel = context.WithTimeout(context.Background(), m.config.Timeout) 319 defer cancel() 320 // set entries channel 321 p.Entries = entries 322 // set the domain 323 p.Domain = m.domain 324 325 go func() { 326 for { 327 select { 328 case e := <-entries: 329 // list record so skip 330 if p.Service == "_services" { 331 continue 332 } 333 if p.Domain != m.domain { 334 continue 335 } 336 if e.TTL == 0 { 337 continue 338 } 339 340 txt, err := decode(e.InfoFields) 341 if err != nil { 342 continue 343 } 344 345 if txt.Service != service { 346 continue 347 } 348 349 s, ok := serviceMap[txt.Version] 350 if !ok { 351 s = ®istry.Service{ 352 Name: txt.Service, 353 Version: txt.Version, 354 Endpoints: txt.Endpoints, 355 } 356 } 357 addr := "" 358 // prefer ipv4 addrs 359 if len(e.AddrV4) > 0 { 360 addr = e.AddrV4.String() 361 // else use ipv6 362 } else if len(e.AddrV6) > 0 { 363 addr = "[" + e.AddrV6.String() + "]" 364 } else { 365 //if logger.V(logger.InfoLevel, logger.DefaultLogger) { 366 log.Infof("[mdns]: invalid endpoint received: %v", e) 367 //} 368 continue 369 } 370 s.Nodes = append(s.Nodes, ®istry.Node{ 371 Id: strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+"."), 372 Address: fmt.Sprintf("%s:%d", addr, e.Port), 373 Metadata: txt.Metadata, 374 }) 375 376 serviceMap[txt.Version] = s 377 case <-p.Context.Done(): 378 close(done) 379 return 380 } 381 } 382 }() 383 384 // execute the query 385 if err := mdns.Query(p); err != nil { 386 return nil, err 387 } 388 389 // wait for completion 390 <-done 391 392 // create list and return 393 services := make([]*registry.Service, 0, len(serviceMap)) 394 395 for _, service := range serviceMap { 396 services = append(services, service) 397 } 398 399 return services, nil 400 } 401 402 func (m *mdnsRegistry) ListServices() ([]*registry.Service, error) { 403 serviceMap := make(map[string]bool) 404 entries := make(chan *mdns.ServiceEntry, 10) 405 done := make(chan bool) 406 407 p := mdns.DefaultParams("_services") 408 // set context with timeout 409 var cancel context.CancelFunc 410 p.Context, cancel = context.WithTimeout(context.Background(), m.config.Timeout) 411 defer cancel() 412 // set entries channel 413 p.Entries = entries 414 // set domain 415 p.Domain = m.domain 416 417 var services []*registry.Service 418 419 go func() { 420 for { 421 select { 422 case e := <-entries: 423 if e.TTL == 0 { 424 continue 425 } 426 if !strings.HasSuffix(e.Name, p.Domain+".") { 427 continue 428 } 429 name := strings.TrimSuffix(e.Name, "."+p.Service+"."+p.Domain+".") 430 if !serviceMap[name] { 431 serviceMap[name] = true 432 services = append(services, ®istry.Service{Name: name}) 433 } 434 case <-p.Context.Done(): 435 close(done) 436 return 437 } 438 } 439 }() 440 441 // execute query 442 if err := mdns.Query(p); err != nil { 443 return nil, err 444 } 445 446 // wait till done 447 <-done 448 449 return services, nil 450 } 451 452 func (m *mdnsRegistry) Watcher(opts ...registry.WatchOptions) (registry.Watcher, error) { 453 wcfg := ®istry.WatchConfig{} 454 for _, opt := range opts { 455 opt(wcfg) 456 } 457 458 md := &mdnsWatcher{ 459 id: uuid.New().String(), 460 wo: wcfg, 461 ch: make(chan *mdns.ServiceEntry, 32), 462 exit: make(chan struct{}), 463 domain: m.domain, 464 registry: m, 465 } 466 467 m.mtx.Lock() 468 defer m.mtx.Unlock() 469 470 // save the watcher 471 m.watchers[md.id] = md 472 473 // check of the listener exists 474 if m.listener != nil { 475 return md, nil 476 } 477 478 // start the listener 479 go func() { 480 // go to infinity 481 for { 482 m.mtx.Lock() 483 484 // just return if there are no watchers 485 if len(m.watchers) == 0 { 486 m.listener = nil 487 m.mtx.Unlock() 488 return 489 } 490 491 // check existing listener 492 if m.listener != nil { 493 m.mtx.Unlock() 494 return 495 } 496 497 // reset the listener 498 exit := make(chan struct{}) 499 ch := make(chan *mdns.ServiceEntry, 32) 500 m.listener = ch 501 502 m.mtx.Unlock() 503 504 // send messages to the watchers 505 go func() { 506 send := func(w *mdnsWatcher, e *mdns.ServiceEntry) { 507 select { 508 case w.ch <- e: 509 default: 510 } 511 } 512 513 for { 514 select { 515 case <-exit: 516 return 517 case e, ok := <-ch: 518 if !ok { 519 return 520 } 521 m.mtx.RLock() 522 // send service entry to all watchers 523 for _, w := range m.watchers { 524 send(w, e) 525 } 526 m.mtx.RUnlock() 527 } 528 } 529 530 }() 531 532 // start listening, blocking call 533 mdns.Listen(ch, exit) 534 535 // mdns.Listen has unblocked 536 // kill the saved listener 537 m.mtx.Lock() 538 m.listener = nil 539 close(ch) 540 m.mtx.Unlock() 541 } 542 }() 543 544 return md, nil 545 } 546 547 func (m *mdnsRegistry) String() string { 548 return m.config.Name 549 } 550 551 func (m *mdnsWatcher) Next() (*registry.Result, error) { 552 for { 553 select { 554 case e := <-m.ch: 555 txt, err := decode(e.InfoFields) 556 if err != nil { 557 continue 558 } 559 560 if len(txt.Service) == 0 || len(txt.Version) == 0 { 561 continue 562 } 563 564 // Filter watch options 565 // wo.Service: Only keep services we care about 566 if len(m.wo.Service) > 0 && txt.Service != m.wo.Service { 567 continue 568 } 569 var action string 570 if e.TTL == 0 { 571 action = "delete" 572 } else { 573 action = "create" 574 } 575 576 service := ®istry.Service{ 577 Name: txt.Service, 578 Version: txt.Version, 579 Endpoints: txt.Endpoints, 580 } 581 582 // skip anything without the domain we care about 583 suffix := fmt.Sprintf(".%s.%s.", service.Name, m.domain) 584 if !strings.HasSuffix(e.Name, suffix) { 585 continue 586 } 587 588 var addr string 589 if len(e.AddrV4) > 0 { 590 addr = e.AddrV4.String() 591 } else if len(e.AddrV6) > 0 { 592 addr = "[" + e.AddrV6.String() + "]" 593 } else { 594 addr = e.Addr.String() 595 } 596 597 service.Nodes = append(service.Nodes, ®istry.Node{ 598 Id: strings.TrimSuffix(e.Name, suffix), 599 Address: fmt.Sprintf("%s:%d", addr, e.Port), 600 Metadata: txt.Metadata, 601 }) 602 603 return ®istry.Result{ 604 Action: action, 605 Service: service, 606 }, nil 607 case <-m.exit: 608 return nil, registry.ErrWatcherStopped 609 } 610 } 611 } 612 613 func (m *mdnsWatcher) Stop() { 614 select { 615 case <-m.exit: 616 return 617 default: 618 close(m.exit) 619 // remove self from the registry 620 m.registry.mtx.Lock() 621 delete(m.registry.watchers, m.id) 622 m.registry.mtx.Unlock() 623 } 624 }