github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/internal/mdns/client.go (about) 1 package mdns 2 3 import ( 4 "context" 5 "fmt" 6 "log" 7 "net" 8 "strings" 9 "sync" 10 "time" 11 12 "github.com/miekg/dns" 13 "golang.org/x/net/ipv4" 14 "golang.org/x/net/ipv6" 15 ) 16 17 // ServiceEntry is returned after we query for a service 18 type ServiceEntry struct { 19 Name string 20 Host string 21 AddrV4 net.IP 22 AddrV6 net.IP 23 Port int 24 Info string 25 InfoFields []string 26 TTL int 27 Type uint16 28 29 Addr net.IP // @Deprecated 30 31 hasTXT bool 32 sent bool 33 } 34 35 // complete is used to check if we have all the info we need 36 func (s *ServiceEntry) complete() bool { 37 38 return (len(s.AddrV4) > 0 || len(s.AddrV6) > 0 || len(s.Addr) > 0) && s.Port != 0 && s.hasTXT 39 } 40 41 // QueryParam is used to customize how a Lookup is performed 42 type QueryParam struct { 43 Service string // Service to lookup 44 Domain string // Lookup domain, default "local" 45 Type uint16 // Lookup type, defaults to dns.TypePTR 46 Context context.Context // Context 47 Timeout time.Duration // Lookup timeout, default 1 second. Ignored if Context is provided 48 Interface *net.Interface // Multicast interface to use 49 Entries chan<- *ServiceEntry // Entries Channel 50 WantUnicastResponse bool // Unicast response desired, as per 5.4 in RFC 51 } 52 53 // DefaultParams is used to return a default set of QueryParam's 54 func DefaultParams(service string) *QueryParam { 55 return &QueryParam{ 56 Service: service, 57 Domain: "local", 58 Timeout: time.Second, 59 Entries: make(chan *ServiceEntry), 60 WantUnicastResponse: false, // TODO(reddaly): Change this default. 61 } 62 } 63 64 // Query looks up a given service, in a domain, waiting at most 65 // for a timeout before finishing the query. The results are streamed 66 // to a channel. Sends will not block, so clients should make sure to 67 // either read or buffer. 68 func Query(params *QueryParam) error { 69 // Create a new client 70 client, err := newClient() 71 if err != nil { 72 return err 73 } 74 defer client.Close() 75 76 // Set the multicast interface 77 if params.Interface != nil { 78 if err := client.setInterface(params.Interface, false); err != nil { 79 return err 80 } 81 } 82 83 // Ensure defaults are set 84 if params.Domain == "" { 85 params.Domain = "local" 86 } 87 88 if params.Context == nil { 89 if params.Timeout == 0 { 90 params.Timeout = time.Second 91 } 92 var cancel context.CancelFunc 93 params.Context, cancel = context.WithTimeout(context.Background(), params.Timeout) 94 defer cancel() 95 } 96 97 // Run the query 98 return client.query(params) 99 } 100 101 // Listen listens indefinitely for multicast updates 102 func Listen(entries chan<- *ServiceEntry, exit chan struct{}) error { 103 // Create a new client 104 client, err := newClient() 105 if err != nil { 106 return err 107 } 108 defer client.Close() 109 110 client.setInterface(nil, true) 111 112 // Start listening for response packets 113 msgCh := make(chan *dns.Msg, 32) 114 115 go client.recv(client.ipv4UnicastConn, msgCh) 116 go client.recv(client.ipv6UnicastConn, msgCh) 117 go client.recv(client.ipv4MulticastConn, msgCh) 118 go client.recv(client.ipv6MulticastConn, msgCh) 119 120 ip := make(map[string]*ServiceEntry) 121 122 for { 123 select { 124 case <-exit: 125 return nil 126 case <-client.closedCh: 127 return nil 128 case m := <-msgCh: 129 e := messageToEntry(m, ip) 130 if e == nil { 131 continue 132 } 133 134 // Check if this entry is complete 135 if e.complete() { 136 if e.sent { 137 continue 138 } 139 e.sent = true 140 entries <- e 141 ip = make(map[string]*ServiceEntry) 142 } else { 143 // Fire off a node specific query 144 m := new(dns.Msg) 145 m.SetQuestion(e.Name, dns.TypePTR) 146 m.RecursionDesired = false 147 if err := client.sendQuery(m); err != nil { 148 log.Printf("[ERR] mdns: Failed to query instance %s: %v", e.Name, err) 149 } 150 } 151 } 152 } 153 } 154 155 // Lookup is the same as Query, however it uses all the default parameters 156 func Lookup(service string, entries chan<- *ServiceEntry) error { 157 params := DefaultParams(service) 158 params.Entries = entries 159 return Query(params) 160 } 161 162 // Client provides a query interface that can be used to 163 // search for service providers using mDNS 164 type client struct { 165 ipv4UnicastConn *net.UDPConn 166 ipv6UnicastConn *net.UDPConn 167 168 ipv4MulticastConn *net.UDPConn 169 ipv6MulticastConn *net.UDPConn 170 171 closed bool 172 closedCh chan struct{} // TODO(reddaly): This doesn't appear to be used. 173 closeLock sync.Mutex 174 } 175 176 // NewClient creates a new mdns Client that can be used to query 177 // for records 178 func newClient() (*client, error) { 179 // TODO(reddaly): At least attempt to bind to the port required in the spec. 180 // Create a IPv4 listener 181 uconn4, err4 := net.ListenUDP("udp4", &net.UDPAddr{IP: net.IPv4zero, Port: 0}) 182 uconn6, err6 := net.ListenUDP("udp6", &net.UDPAddr{IP: net.IPv6zero, Port: 0}) 183 if err4 != nil && err6 != nil { 184 log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) 185 } 186 187 if uconn4 == nil && uconn6 == nil { 188 return nil, fmt.Errorf("failed to bind to any unicast udp port") 189 } 190 191 if uconn4 == nil { 192 uconn4 = &net.UDPConn{} 193 } 194 195 if uconn6 == nil { 196 uconn6 = &net.UDPConn{} 197 } 198 199 mconn4, err4 := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) 200 mconn6, err6 := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) 201 if err4 != nil && err6 != nil { 202 log.Printf("[ERR] mdns: Failed to bind to udp port: %v %v", err4, err6) 203 } 204 205 if mconn4 == nil && mconn6 == nil { 206 return nil, fmt.Errorf("failed to bind to any multicast udp port") 207 } 208 209 if mconn4 == nil { 210 mconn4 = &net.UDPConn{} 211 } 212 213 if mconn6 == nil { 214 mconn6 = &net.UDPConn{} 215 } 216 217 p1 := ipv4.NewPacketConn(mconn4) 218 p2 := ipv6.NewPacketConn(mconn6) 219 p1.SetMulticastLoopback(true) 220 p2.SetMulticastLoopback(true) 221 222 ifaces, err := net.Interfaces() 223 if err != nil { 224 return nil, err 225 } 226 227 var errCount1, errCount2 int 228 229 for _, iface := range ifaces { 230 if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { 231 errCount1++ 232 } 233 if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { 234 errCount2++ 235 } 236 } 237 238 if len(ifaces) == errCount1 && len(ifaces) == errCount2 { 239 return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") 240 } 241 242 c := &client{ 243 ipv4MulticastConn: mconn4, 244 ipv6MulticastConn: mconn6, 245 ipv4UnicastConn: uconn4, 246 ipv6UnicastConn: uconn6, 247 closedCh: make(chan struct{}), 248 } 249 return c, nil 250 } 251 252 // Close is used to cleanup the client 253 func (c *client) Close() error { 254 c.closeLock.Lock() 255 defer c.closeLock.Unlock() 256 257 if c.closed { 258 return nil 259 } 260 c.closed = true 261 262 close(c.closedCh) 263 264 if c.ipv4UnicastConn != nil { 265 c.ipv4UnicastConn.Close() 266 } 267 if c.ipv6UnicastConn != nil { 268 c.ipv6UnicastConn.Close() 269 } 270 if c.ipv4MulticastConn != nil { 271 c.ipv4MulticastConn.Close() 272 } 273 if c.ipv6MulticastConn != nil { 274 c.ipv6MulticastConn.Close() 275 } 276 277 return nil 278 } 279 280 // setInterface is used to set the query interface, uses system 281 // default if not provided 282 func (c *client) setInterface(iface *net.Interface, loopback bool) error { 283 p := ipv4.NewPacketConn(c.ipv4UnicastConn) 284 if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { 285 return err 286 } 287 p2 := ipv6.NewPacketConn(c.ipv6UnicastConn) 288 if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { 289 return err 290 } 291 p = ipv4.NewPacketConn(c.ipv4MulticastConn) 292 if err := p.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { 293 return err 294 } 295 p2 = ipv6.NewPacketConn(c.ipv6MulticastConn) 296 if err := p2.JoinGroup(iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { 297 return err 298 } 299 300 if loopback { 301 p.SetMulticastLoopback(true) 302 p2.SetMulticastLoopback(true) 303 } 304 305 return nil 306 } 307 308 // query is used to perform a lookup and stream results 309 func (c *client) query(params *QueryParam) error { 310 // Create the service name 311 serviceAddr := fmt.Sprintf("%s.%s.", trimDot(params.Service), trimDot(params.Domain)) 312 313 // Start listening for response packets 314 msgCh := make(chan *dns.Msg, 32) 315 go c.recv(c.ipv4UnicastConn, msgCh) 316 go c.recv(c.ipv6UnicastConn, msgCh) 317 go c.recv(c.ipv4MulticastConn, msgCh) 318 go c.recv(c.ipv6MulticastConn, msgCh) 319 320 // Send the query 321 m := new(dns.Msg) 322 if params.Type == dns.TypeNone { 323 m.SetQuestion(serviceAddr, dns.TypePTR) 324 } else { 325 m.SetQuestion(serviceAddr, params.Type) 326 } 327 // RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question 328 // Section 329 // 330 // In the Question Section of a Multicast DNS query, the top bit of the qclass 331 // field is used to indicate that unicast responses are preferred for this 332 // particular question. (See Section 5.4.) 333 if params.WantUnicastResponse { 334 m.Question[0].Qclass |= 1 << 15 335 } 336 m.RecursionDesired = false 337 if err := c.sendQuery(m); err != nil { 338 return err 339 } 340 341 // Map the in-progress responses 342 inprogress := make(map[string]*ServiceEntry) 343 344 for { 345 select { 346 case resp := <-msgCh: 347 inp := messageToEntry(resp, inprogress) 348 349 if inp == nil { 350 continue 351 } 352 if len(resp.Question) == 0 || resp.Question[0].Name != m.Question[0].Name { 353 // discard anything which we've not asked for 354 continue 355 } 356 357 // Check if this entry is complete 358 if inp.complete() { 359 if inp.sent { 360 continue 361 } 362 363 inp.sent = true 364 select { 365 case params.Entries <- inp: 366 case <-params.Context.Done(): 367 return nil 368 } 369 } else { 370 // Fire off a node specific query 371 m := new(dns.Msg) 372 m.SetQuestion(inp.Name, inp.Type) 373 m.RecursionDesired = false 374 if err := c.sendQuery(m); err != nil { 375 log.Printf("[ERR] mdns: Failed to query instance %s: %v", inp.Name, err) 376 } 377 } 378 case <-params.Context.Done(): 379 return nil 380 } 381 } 382 } 383 384 // sendQuery is used to multicast a query out 385 func (c *client) sendQuery(q *dns.Msg) error { 386 buf, err := q.Pack() 387 if err != nil { 388 return err 389 } 390 if c.ipv4UnicastConn != nil { 391 c.ipv4UnicastConn.WriteToUDP(buf, ipv4Addr) 392 } 393 if c.ipv6UnicastConn != nil { 394 c.ipv6UnicastConn.WriteToUDP(buf, ipv6Addr) 395 } 396 return nil 397 } 398 399 // recv is used to receive until we get a shutdown 400 func (c *client) recv(l *net.UDPConn, msgCh chan *dns.Msg) { 401 if l == nil { 402 return 403 } 404 buf := make([]byte, 65536) 405 for { 406 c.closeLock.Lock() 407 if c.closed { 408 c.closeLock.Unlock() 409 return 410 } 411 c.closeLock.Unlock() 412 n, err := l.Read(buf) 413 if err != nil { 414 continue 415 } 416 msg := new(dns.Msg) 417 if err := msg.Unpack(buf[:n]); err != nil { 418 continue 419 } 420 select { 421 case msgCh <- msg: 422 case <-c.closedCh: 423 return 424 } 425 } 426 } 427 428 // ensureName is used to ensure the named node is in progress 429 func ensureName(inprogress map[string]*ServiceEntry, name string, typ uint16) *ServiceEntry { 430 if inp, ok := inprogress[name]; ok { 431 return inp 432 } 433 inp := &ServiceEntry{ 434 Name: name, 435 Type: typ, 436 } 437 inprogress[name] = inp 438 return inp 439 } 440 441 // alias is used to setup an alias between two entries 442 func alias(inprogress map[string]*ServiceEntry, src, dst string, typ uint16) { 443 srcEntry := ensureName(inprogress, src, typ) 444 inprogress[dst] = srcEntry 445 } 446 447 func messageToEntry(m *dns.Msg, inprogress map[string]*ServiceEntry) *ServiceEntry { 448 var inp *ServiceEntry 449 450 for _, answer := range append(m.Answer, m.Extra...) { 451 // TODO(reddaly): Check that response corresponds to serviceAddr? 452 switch rr := answer.(type) { 453 case *dns.PTR: 454 // Create new entry for this 455 inp = ensureName(inprogress, rr.Ptr, rr.Hdr.Rrtype) 456 if inp.complete() { 457 continue 458 } 459 case *dns.SRV: 460 // Check for a target mismatch 461 if rr.Target != rr.Hdr.Name { 462 alias(inprogress, rr.Hdr.Name, rr.Target, rr.Hdr.Rrtype) 463 } 464 465 // Get the port 466 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 467 if inp.complete() { 468 continue 469 } 470 inp.Host = rr.Target 471 inp.Port = int(rr.Port) 472 case *dns.TXT: 473 // Pull out the txt 474 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 475 if inp.complete() { 476 continue 477 } 478 inp.Info = strings.Join(rr.Txt, "|") 479 inp.InfoFields = rr.Txt 480 inp.hasTXT = true 481 case *dns.A: 482 // Pull out the IP 483 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 484 if inp.complete() { 485 continue 486 } 487 inp.Addr = rr.A // @Deprecated 488 inp.AddrV4 = rr.A 489 case *dns.AAAA: 490 // Pull out the IP 491 inp = ensureName(inprogress, rr.Hdr.Name, rr.Hdr.Rrtype) 492 if inp.complete() { 493 continue 494 } 495 inp.Addr = rr.AAAA // @Deprecated 496 inp.AddrV6 = rr.AAAA 497 } 498 499 if inp != nil { 500 inp.TTL = int(answer.Header().Ttl) 501 } 502 } 503 504 return inp 505 }