github.com/micro/go-micro/v2@v2.9.1/util/mdns/client.go (about) 1 package mdns 2 3 import ( 4 "context" 5 "fmt" 6 "log" 7 "net" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/miekg/dns" 13 "golang.org/x/net/ipv4" 14 "golang.org/x/net/ipv6" 15 ) 16 17 // ServiceEntry is returned after we query for a service 18 type ServiceEntry struct { 19 Name string 20 Host string 21 AddrV4 net.IP 22 AddrV6 net.IP 23 Port int 24 Info string 25 InfoFields []string 26 TTL int 27 Type uint16 28 29 Addr net.IP // @Deprecated 30 31 hasTXT bool 32 sent bool 33 } 34 35 // complete is used to check if we have all the info we need 36 func (s *ServiceEntry) complete() bool { 37 38 return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT 39 } 40 41 // QueryParam is used to customize how a Lookup is performed 42 type QueryParam struct { 43 Service string // Service to lookup 44 Domain string // Lookup domain, default "local" 45 Type uint16 // Lookup type, defaults to dns.TypePTR 46 Context context.Context // Context 47 Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided 48 Interface *net.Interface // Multicast interface to use 49 Entries chan<- *ServiceEntry // Entries Channel 50 WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC 51 } 52 53 // DefaultParams is used to return a default set of QueryParam's 54 func DefaultParams(service string) *QueryParam { 55 return &QueryParam{ 56 Service: service, 57 Domain: "local", 58 Timeout: time.Second, 59 Entries: make(chan *ServiceEntry), 60 WantUnicastResponse: false, // TODO(reddaly): Change this default. 61 } 62 } 63 64 // Query looks up a given service, in a domain, waiting at most 65 // for a timeout before finishing the query. The results are streamed 66 // to a channel. Sends will not block, so clients should make sure to 67 // either read or buffer. 68 func Query(params *QueryParam) error { 69 // Create a new client 70 client, err := newClient() 71 if err != nil { 72 return err 73 } 74 defer client.Close() 75 76 // Set the multicast interface 77 if params.Interface != nil { 78 if err := client.setInterface(params.Interface, false); err != nil { 79 return err 80 } 81 } 82 83 // Ensure defaults are set 84 if params.Domain == "" { 85 params.Domain = "local" 86 } 87 88 if params.Context == nil { 89 if params.Timeout == 0 { 90 params.Timeout = time.Second 91 } 92 params.Context, _ = context.WithTimeout(context.Background(), params.Timeout) 93 if err != nil { 94 return err 95 } 96 } 97 98 // Run the query 99 return client.query(params) 100 } 101 102 // Listen listens indefinitely for multicast updates 103 func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error { 104 // Create a new client 105 client, err := newClient() 106 if err != nil { 107 return err 108 } 109 defer client.Close() 110 111 client.setInterface(nil, true) 112 113 // Start listening for response packets 114 msgCh := make(chan *dns.Msg, 32) 115 116 go client.recv(client.ipv4UnicastConn, msgCh) 117 go client.recv(client.ipv6UnicastConn, msgCh) 118 go client.recv(client.ipv4MulticastConn, msgCh) 119 go client.recv(client.ipv6MulticastConn, msgCh) 120 121 ip := make(map[string]*ServiceEntry) 122 123 for { 124 select { 125 case <-exit: 126 return nil 127 case <-client.closedCh: 128 return nil 129 case m := <-msgCh: 130 e := messageToEntry(m, ip) 131 if e == nil { 132 continue 133 } 134 135 // Check if this entry is complete 136 if e.complete() { 137 if e.sent { 138 continue 139 } 140 e.sent = true 141 entries <- e 142 ip = make(map[string]*ServiceEntry) 143 } else { 144 // Fire off a node specific query 145 m := new(dns.Msg) 146 m.SetQuestion(e.Name, dns.TypePTR) 147 m.RecursionDesired = false 148 if err := client.sendQuery(m); err != nil { 149 log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err) 150 } 151 } 152 } 153 } 154 155 return nil 156 } 157 158 // Lookup is the same as Query, however it uses all the default parameters 159 func Lookup(service string, entries chan<- *ServiceEntry) error { 160 params := DefaultParams(service) 161 params.Entries = entries 162 return Query(params) 163 } 164 165 // Client provides a query interface that can be used to 166 // search for service providers using mDNS 167 type client struct { 168 ipv4UnicastConn *net.UDPConn 169 ipv6UnicastConn *net.UDPConn 170 171 ipv4MulticastConn *net.UDPConn 172 ipv6MulticastConn *net.UDPConn 173 174 closed bool 175 closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used. 176 closeLock sync.Mutex 177 } 178 179 // NewClient creates a new mdns Client that can be used to query 180 // for records 181 func newClient() (*client, error) { 182 // TODO(reddaly): At least attempt to bind to the port required in the spec. 183 // Create a IPv4 listener 184 uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) 185 uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) 186 if err4 != nil && err6 != nil { 187 log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) 188 } 189 190 if uconn4 == nil && uconn6 == nil { 191 return nil, fmt.Errorf("failed to bind to any unicast udp port") 192 } 193 194 if uconn4 == nil { 195 uconn4 = &net.UDPConn{} 196 } 197 198 if uconn6 == nil { 199 uconn6 = &net.UDPConn{} 200 } 201 202 mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) 203 mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) 204 if err4 != nil && err6 != nil { 205 log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) 206 } 207 208 if mconn4 == nil && mconn6 == nil { 209 return nil, fmt.Errorf("failed to bind to any multicast udp port") 210 } 211 212 if mconn4 == nil { 213 mconn4 = &net.UDPConn{} 214 } 215 216 if mconn6 == nil { 217 mconn6 = &net.UDPConn{} 218 } 219 220 p1 := ipv4.NewPacketConn(mconn4) 221 p2 := ipv6.NewPacketConn(mconn6) 222 p1.SetMulticastLoopback(true) 223 p2.SetMulticastLoopback(true) 224 225 ifaces, err := net.Interfaces() 226 if err != nil { 227 return nil, err 228 } 229 230 var errCount1, errCount2 int 231 232 for _, iface := range ifaces { 233 if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { 234 errCount1++ 235 } 236 if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { 237 errCount2++ 238 } 239 } 240 241 if len(ifaces) == errCount1 && len(ifaces) == errCount2 { 242 return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") 243 } 244 245 c := &client{ 246 ipv4MulticastConn: mconn4, 247 ipv6MulticastConn: mconn6, 248 ipv4UnicastConn: uconn4, 249 ipv6UnicastConn: uconn6, 250 closedCh: make(chan struct{}), 251 } 252 return c, nil 253 } 254 255 // Close is used to cleanup the client 256 func (c *client) Close() error { 257 c.closeLock.Lock() 258 defer c.closeLock.Unlock() 259 260 if c.closed { 261 return nil 262 } 263 c.closed = true 264 265 close(c.closedCh) 266 267 if c.ipv4UnicastConn != nil { 268 c.ipv4UnicastConn.Close() 269 } 270 if c.ipv6UnicastConn != nil { 271 c.ipv6UnicastConn.Close() 272 } 273 if c.ipv4MulticastConn != nil { 274 c.ipv4MulticastConn.Close() 275 } 276 if c.ipv6MulticastConn != nil { 277 c.ipv6MulticastConn.Close() 278 } 279 280 return nil 281 } 282 283 // setInterface is used to set the query interface, uses sytem 284 // default if not provided 285 func (c *client) setInterface(iface *net.Interface, loopback bool) error { 286 p := ipv4.NewPacketConn(c.ipv4UnicastConn) 287 if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { 288 return err 289 } 290 p2 := ipv6.NewPacketConn(c.ipv6UnicastConn) 291 if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { 292 return err 293 } 294 p = ipv4.NewPacketConn(c.ipv4MulticastConn) 295 if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { 296 return err 297 } 298 p2 = ipv6.NewPacketConn(c.ipv6MulticastConn) 299 if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { 300 return err 301 } 302 303 if loopback { 304 p.SetMulticastLoopback(true) 305 p2.SetMulticastLoopback(true) 306 } 307 308 return nil 309 } 310 311 // query is used to perform a lookup and stream results 312 func (c *client) query(params *QueryParam) error { 313 // Create the service name 314 serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) 315 316 // Start listening for response packets 317 msgCh := make(chan *dns.Msg, 32) 318 go c.recv(c.ipv4UnicastConn, msgCh) 319 go c.recv(c.ipv6UnicastConn, msgCh) 320 go c.recv(c.ipv4MulticastConn, msgCh) 321 go c.recv(c.ipv6MulticastConn, msgCh) 322 323 // Send the query 324 m := new(dns.Msg) 325 if params.Type == dns.TypeNone { 326 m.SetQuestion(serviceAddr, dns.TypePTR) 327 } else { 328 m.SetQuestion(serviceAddr, params.Type) 329 } 330 // RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question 331 // Section 332 // 333 // In the Question Section of a Multicast DNS query, the top bit of the qclass 334 // field is used to indicate that unicast responses are preferred for this 335 // particular question. (See Section 5.4.) 336 if params.WantUnicastResponse { 337 m.Question[0].Qclass |= 1 << 15 338 } 339 m.RecursionDesired = false 340 if err := c.sendQuery(m); err != nil { 341 return err 342 } 343 344 // Map the in-progress responses 345 inprogress := make(map[string]*ServiceEntry) 346 347 for { 348 select { 349 case resp := <-msgCh: 350 inp := messageToEntry(resp, inprogress) 351 352 if inp == nil { 353 continue 354 } 355 if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name { 356 // discard anything which we've not asked for 357 continue 358 } 359 360 // Check if this entry is complete 361 if inp.complete() { 362 if inp.sent { 363 continue 364 } 365 366 inp.sent = true 367 select { 368 case params.Entries <- inp: 369 case <-params.Context.Done(): 370 return nil 371 } 372 } else { 373 // Fire off a node specific query 374 m := new(dns.Msg) 375 m.SetQuestion(inp.Name, inp.Type) 376 m.RecursionDesired = false 377 if err := c.sendQuery(m); err != nil { 378 log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) 379 } 380 } 381 case <-params.Context.Done(): 382 return nil 383 } 384 } 385 } 386 387 // sendQuery is used to multicast a query out 388 func (c *client) sendQuery(q *dns.Msg) error { 389 buf, err := q.Pack() 390 if err != nil { 391 return err 392 } 393 if c.ipv4UnicastConn != nil { 394 c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr) 395 } 396 if c.ipv6UnicastConn != nil { 397 c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr) 398 } 399 return nil 400 } 401 402 // recv is used to receive until we get a shutdown 403 func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) { 404 if l == nil { 405 return 406 } 407 buf := make([]byte, 65536) 408 for { 409 c.closeLock.Lock() 410 if c.closed { 411 c.closeLock.Unlock() 412 return 413 } 414 c.closeLock.Unlock() 415 n, err := l.Read(buf) 416 if err != nil { 417 continue 418 } 419 msg := new(dns.Msg) 420 if err := msg.Unpack(buf[:n]); err != nil { 421 continue 422 } 423 select { 424 case msgCh <- msg: 425 case <-c.closedCh: 426 return 427 } 428 } 429 } 430 431 // ensureName is used to ensure the named node is in progress 432 func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry { 433 if inp, ok := inprogress[name]; ok { 434 return inp 435 } 436 inp := &ServiceEntry{ 437 Name: name, 438 Type: typ, 439 } 440 inprogress[name] = inp 441 return inp 442 } 443 444 // alias is used to setup an alias between two entries 445 func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) { 446 srcEntry := ensureName(inprogress, src, typ) 447 inprogress[dst] = srcEntry 448 } 449 450 func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry { 451 var inp *ServiceEntry 452 453 for _, answer := range append(m.Answer, m.Extra...) { 454 // TODO(reddaly): Check that response corresponds to serviceAddr? 455 switch rr := answer.(type) { 456 case *dns.PTR: 457 // Create new entry for this 458 inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype) 459 if inp.complete() { 460 continue 461 } 462 case *dns.SRV: 463 // Check for a target mismatch 464 if rr.Target != rr.Hdr.Name { 465 alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype) 466 } 467 468 // Get the port 469 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 470 if inp.complete() { 471 continue 472 } 473 inp.Host = rr.Target 474 inp.Port = int(rr.Port) 475 case *dns.TXT: 476 // Pull out the txt 477 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 478 if inp.complete() { 479 continue 480 } 481 inp.Info = strings.Join(rr.Txt, "|") 482 inp.InfoFields = rr.Txt 483 inp.hasTXT = true 484 case *dns.A: 485 // Pull out the IP 486 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 487 if inp.complete() { 488 continue 489 } 490 inp.Addr = rr.A // @Deprecated 491 inp.AddrV4 = rr.A 492 case *dns.AAAA: 493 // Pull out the IP 494 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 495 if inp.complete() { 496 continue 497 } 498 inp.Addr = rr.AAAA // @Deprecated 499 inp.AddrV6 = rr.AAAA 500 } 501 502 if inp != nil { 503 inp.TTL = int(answer.Header().Ttl) 504 } 505 } 506 507 return inp 508 }