go-micro.dev/v5@v5.12.0/util/mdns/client.go (about) 1 package mdns 2 3 import ( 4 "context" 5 "fmt" 6 "net" 7 "strings" 8 "sync" 9 "time" 10 11 "github.com/miekg/dns" 12 "go-micro.dev/v5/logger" 13 "golang.org/x/net/ipv4" 14 "golang.org/x/net/ipv6" 15 ) 16 17 // ServiceEntry is returned after we query for a service. 18 type ServiceEntry struct { 19 Name string 20 Host string 21 Info string 22 AddrV4 net.IP 23 AddrV6 net.IP 24 InfoFields []string 25 26 Addr net.IP // @Deprecated 27 28 Port int 29 TTL int 30 Type uint16 31 32 hasTXT bool 33 sent bool 34 } 35 36 // complete is used to check if we have all the info we need. 37 func (s *ServiceEntry) complete() bool { 38 return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT 39 } 40 41 // QueryParam is used to customize how a Lookup is performed. 42 type QueryParam struct { 43 Context context.Context // Context 44 Interface *net.Interface // Multicast interface to use 45 Entries chan<- *ServiceEntry // Entries Channel 46 Service string // Service to lookup 47 Domain string // Lookup domain, default "local" 48 Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided 49 Type uint16 // Lookup type, defaults to dns.TypePTR 50 WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC 51 } 52 53 // DefaultParams is used to return a default set of QueryParam's. 54 func DefaultParams(service string) *QueryParam { 55 return &QueryParam{ 56 Service: service, 57 Domain: "local", 58 Timeout: time.Second, 59 Entries: make(chan *ServiceEntry), 60 WantUnicastResponse: false, // TODO(reddaly): Change this default. 61 } 62 } 63 64 // Query looks up a given service, in a domain, waiting at most 65 // for a timeout before finishing the query. The results are streamed 66 // to a channel. Sends will not block, so clients should make sure to 67 // either read or buffer. 68 func Query(params *QueryParam) error { 69 // Create a new client 70 client, err := newClient() 71 if err != nil { 72 return err 73 } 74 defer client.Close() 75 76 // Set the multicast interface 77 if params.Interface != nil { 78 if err := client.setInterface(params.Interface, false); err != nil { 79 return err 80 } 81 } 82 83 // Ensure defaults are set 84 if params.Domain == "" { 85 params.Domain = "local" 86 } 87 88 if params.Context == nil { 89 if params.Timeout == 0 { 90 params.Timeout = time.Second 91 } 92 params.Context, _ = context.WithTimeout(context.Background(), params.Timeout) 93 if err != nil { 94 return err 95 } 96 } 97 98 // Run the query 99 return client.query(params) 100 } 101 102 // Listen listens indefinitely for multicast updates. 103 func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error { 104 // Create a new client 105 client, err := newClient() 106 if err != nil { 107 return err 108 } 109 defer client.Close() 110 111 client.setInterface(nil, true) 112 113 // Start listening for response packets 114 msgCh := make(chan *dns.Msg, 32) 115 116 go client.recv(client.ipv4UnicastConn, msgCh) 117 go client.recv(client.ipv6UnicastConn, msgCh) 118 go client.recv(client.ipv4MulticastConn, msgCh) 119 go client.recv(client.ipv6MulticastConn, msgCh) 120 121 ip := make(map[string]*ServiceEntry) 122 123 for { 124 select { 125 case <-exit: 126 return nil 127 case <-client.closedCh: 128 return nil 129 case m := <-msgCh: 130 e := messageToEntry(m, ip) 131 if e == nil { 132 continue 133 } 134 135 // Check if this entry is complete 136 if e.complete() { 137 if e.sent { 138 continue 139 } 140 e.sent = true 141 entries <- e 142 ip = make(map[string]*ServiceEntry) 143 } else { 144 // Fire off a node specific query 145 m := new(dns.Msg) 146 m.SetQuestion(e.Name, dns.TypePTR) 147 m.RecursionDesired = false 148 if err := client.sendQuery(m); err != nil { 149 logger.Logf(logger.ErrorLevel, "[mdns] failed to query instance %s: %v", e.Name, err) 150 } 151 } 152 } 153 } 154 155 return nil 156 } 157 158 // Lookup is the same as Query, however it uses all the default parameters. 159 func Lookup(service string, entries chan<- *ServiceEntry) error { 160 params := DefaultParams(service) 161 params.Entries = entries 162 return Query(params) 163 } 164 165 // Client provides a query interface that can be used to 166 // search for service providers using mDNS. 167 type client struct { 168 ipv4UnicastConn *net.UDPConn 169 ipv6UnicastConn *net.UDPConn 170 171 ipv4MulticastConn *net.UDPConn 172 ipv6MulticastConn *net.UDPConn 173 174 closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used. 175 closeLock sync.Mutex 176 177 closed bool 178 } 179 180 // NewClient creates a new mdns Client that can be used to query 181 // for records. 182 func newClient() (*client, error) { 183 // TODO(reddaly): At least attempt to bind to the port required in the spec. 184 // Create a IPv4 listener 185 uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) 186 uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) 187 if err4 != nil && err6 != nil { 188 logger.Logf(logger.ErrorLevel, "[mdns] failed to bind to udp port: %v %v", err4, err6) 189 } 190 191 if uconn4 == nil && uconn6 == nil { 192 return nil, fmt.Errorf("failed to bind to any unicast udp port") 193 } 194 195 if uconn4 == nil { 196 uconn4 = &net.UDPConn{} 197 } 198 199 if uconn6 == nil { 200 uconn6 = &net.UDPConn{} 201 } 202 203 mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) 204 mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) 205 if err4 != nil && err6 != nil { 206 logger.Logf(logger.ErrorLevel, "[mdns] failed to bind to udp port: %v %v", err4, err6) 207 } 208 209 if mconn4 == nil && mconn6 == nil { 210 return nil, fmt.Errorf("failed to bind to any multicast udp port") 211 } 212 213 if mconn4 == nil { 214 mconn4 = &net.UDPConn{} 215 } 216 217 if mconn6 == nil { 218 mconn6 = &net.UDPConn{} 219 } 220 221 p1 := ipv4.NewPacketConn(mconn4) 222 p2 := ipv6.NewPacketConn(mconn6) 223 p1.SetMulticastLoopback(true) 224 p2.SetMulticastLoopback(true) 225 226 ifaces, err := net.Interfaces() 227 if err != nil { 228 return nil, err 229 } 230 231 var errCount1, errCount2 int 232 233 for _, iface := range ifaces { 234 if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { 235 errCount1++ 236 } 237 if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { 238 errCount2++ 239 } 240 } 241 242 if len(ifaces) == errCount1 && len(ifaces) == errCount2 { 243 return nil, fmt.Errorf("failed to join multicast group on all interfaces") 244 } 245 246 c := &client{ 247 ipv4MulticastConn: mconn4, 248 ipv6MulticastConn: mconn6, 249 ipv4UnicastConn: uconn4, 250 ipv6UnicastConn: uconn6, 251 closedCh: make(chan struct{}), 252 } 253 return c, nil 254 } 255 256 // Close is used to cleanup the client. 257 func (c *client) Close() error { 258 c.closeLock.Lock() 259 defer c.closeLock.Unlock() 260 261 if c.closed { 262 return nil 263 } 264 c.closed = true 265 266 close(c.closedCh) 267 268 if c.ipv4UnicastConn != nil { 269 c.ipv4UnicastConn.Close() 270 } 271 if c.ipv6UnicastConn != nil { 272 c.ipv6UnicastConn.Close() 273 } 274 if c.ipv4MulticastConn != nil { 275 c.ipv4MulticastConn.Close() 276 } 277 if c.ipv6MulticastConn != nil { 278 c.ipv6MulticastConn.Close() 279 } 280 281 return nil 282 } 283 284 // setInterface is used to set the query interface, uses system 285 // default if not provided. 286 func (c *client) setInterface(iface *net.Interface, loopback bool) error { 287 p := ipv4.NewPacketConn(c.ipv4UnicastConn) 288 if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { 289 return err 290 } 291 p2 := ipv6.NewPacketConn(c.ipv6UnicastConn) 292 if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { 293 return err 294 } 295 p = ipv4.NewPacketConn(c.ipv4MulticastConn) 296 if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { 297 return err 298 } 299 p2 = ipv6.NewPacketConn(c.ipv6MulticastConn) 300 if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { 301 return err 302 } 303 304 if loopback { 305 p.SetMulticastLoopback(true) 306 p2.SetMulticastLoopback(true) 307 } 308 309 return nil 310 } 311 312 // query is used to perform a lookup and stream results. 313 func (c *client) query(params *QueryParam) error { 314 // Create the service name 315 serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) 316 317 // Start listening for response packets 318 msgCh := make(chan *dns.Msg, 32) 319 go c.recv(c.ipv4UnicastConn, msgCh) 320 go c.recv(c.ipv6UnicastConn, msgCh) 321 go c.recv(c.ipv4MulticastConn, msgCh) 322 go c.recv(c.ipv6MulticastConn, msgCh) 323 324 // Send the query 325 m := new(dns.Msg) 326 if params.Type == dns.TypeNone { 327 m.SetQuestion(serviceAddr, dns.TypePTR) 328 } else { 329 m.SetQuestion(serviceAddr, params.Type) 330 } 331 // RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question 332 // Section 333 // 334 // In the Question Section of a Multicast DNS query, the top bit of the qclass 335 // field is used to indicate that unicast responses are preferred for this 336 // particular question. (See Section 5.4.) 337 if params.WantUnicastResponse { 338 m.Question[0].Qclass |= 1 << 15 339 } 340 m.RecursionDesired = false 341 if err := c.sendQuery(m); err != nil { 342 return err 343 } 344 345 // Map the in-progress responses 346 inprogress := make(map[string]*ServiceEntry) 347 348 for { 349 select { 350 case resp := <-msgCh: 351 inp := messageToEntry(resp, inprogress) 352 353 if inp == nil { 354 continue 355 } 356 if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name { 357 // discard anything which we've not asked for 358 continue 359 } 360 361 // Check if this entry is complete 362 if inp.complete() { 363 if inp.sent { 364 continue 365 } 366 367 inp.sent = true 368 select { 369 case params.Entries <- inp: 370 case <-params.Context.Done(): 371 return nil 372 } 373 } else { 374 // Fire off a node specific query 375 m := new(dns.Msg) 376 m.SetQuestion(inp.Name, inp.Type) 377 m.RecursionDesired = false 378 if err := c.sendQuery(m); err != nil { 379 logger.Logf(logger.ErrorLevel, "[mdns] failed to query instance %s: %v", inp.Name, err) 380 } 381 } 382 case <-params.Context.Done(): 383 return nil 384 } 385 } 386 } 387 388 // sendQuery is used to multicast a query out. 389 func (c *client) sendQuery(q *dns.Msg) error { 390 buf, err := q.Pack() 391 if err != nil { 392 return err 393 } 394 if c.ipv4UnicastConn != nil { 395 c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr) 396 } 397 if c.ipv6UnicastConn != nil { 398 c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr) 399 } 400 return nil 401 } 402 403 // recv is used to receive until we get a shutdown. 404 func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) { 405 if l == nil { 406 return 407 } 408 buf := make([]byte, 65536) 409 for { 410 c.closeLock.Lock() 411 if c.closed { 412 c.closeLock.Unlock() 413 return 414 } 415 c.closeLock.Unlock() 416 n, err := l.Read(buf) 417 if err != nil { 418 continue 419 } 420 msg := new(dns.Msg) 421 if err := msg.Unpack(buf[:n]); err != nil { 422 continue 423 } 424 select { 425 case msgCh <- msg: 426 case <-c.closedCh: 427 return 428 } 429 } 430 } 431 432 // ensureName is used to ensure the named node is in progress. 433 func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry { 434 if inp, ok := inprogress[name]; ok { 435 return inp 436 } 437 inp := &ServiceEntry{ 438 Name: name, 439 Type: typ, 440 } 441 inprogress[name] = inp 442 return inp 443 } 444 445 // alias is used to setup an alias between two entries. 446 func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) { 447 srcEntry := ensureName(inprogress, src, typ) 448 inprogress[dst] = srcEntry 449 } 450 451 func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry { 452 var inp *ServiceEntry 453 454 for _, answer := range append(m.Answer, m.Extra...) { 455 // TODO(reddaly): Check that response corresponds to serviceAddr? 456 switch rr := answer.(type) { 457 case *dns.PTR: 458 // Create new entry for this 459 inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype) 460 if inp.complete() { 461 continue 462 } 463 case *dns.SRV: 464 // Check for a target mismatch 465 if rr.Target != rr.Hdr.Name { 466 alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype) 467 } 468 469 // Get the port 470 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 471 if inp.complete() { 472 continue 473 } 474 inp.Host = rr.Target 475 inp.Port = int(rr.Port) 476 case *dns.TXT: 477 // Pull out the txt 478 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 479 if inp.complete() { 480 continue 481 } 482 inp.Info = strings.Join(rr.Txt, "|") 483 inp.InfoFields = rr.Txt 484 inp.hasTXT = true 485 case *dns.A: 486 // Pull out the IP 487 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 488 if inp.complete() { 489 continue 490 } 491 inp.Addr = rr.A // @Deprecated 492 inp.AddrV4 = rr.A 493 case *dns.AAAA: 494 // Pull out the IP 495 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 496 if inp.complete() { 497 continue 498 } 499 inp.Addr = rr.AAAA // @Deprecated 500 inp.AddrV6 = rr.AAAA 501 } 502 503 if inp != nil { 504 inp.TTL = int(answer.Header().Ttl) 505 } 506 } 507 508 return inp 509 }