github.com/volts-dev/volts@v0.0.0-20240120094013-5e9c65924106/internal/mdns/server.go (about) 1 package mdns 2 3 import ( 4 "fmt" 5 "math/rand" 6 "net" 7 "sync" 8 "sync/atomic" 9 "time" 10 11 "github.com/miekg/dns" 12 log "github.com/volts-dev/volts/logger" 13 "golang.org/x/net/ipv4" 14 "golang.org/x/net/ipv6" 15 ) 16 17 var ( 18 mdnsGroupIPv4 = net.ParseIP("224.0.0.251") 19 mdnsGroupIPv6 = net.ParseIP("ff02::fb") 20 21 // mDNS wildcard addresses 22 mdnsWildcardAddrIPv4 = &net.UDPAddr{ 23 IP: net.ParseIP("224.0.0.0"), 24 Port: 5353, 25 } 26 mdnsWildcardAddrIPv6 = &net.UDPAddr{ 27 IP: net.ParseIP("ff02::"), 28 Port: 5353, 29 } 30 31 // mDNS endpoint addresses 32 ipv4Addr = &net.UDPAddr{ 33 IP: mdnsGroupIPv4, 34 Port: 5353, 35 } 36 ipv6Addr = &net.UDPAddr{ 37 IP: mdnsGroupIPv6, 38 Port: 5353, 39 } 40 ) 41 42 // GetMachineIP is a func which returns the outbound IP of this machine. 43 // Used by the server to determine whether to attempt send the response on a local address 44 type GetMachineIP func() net.IP 45 46 // Config is used to configure the mDNS server 47 type Config struct { 48 // Zone must be provided to support responding to queries 49 Zone Zone 50 51 // Iface if provided binds the multicast listener to the given 52 // interface. If not provided, the system default multicase interface 53 // is used. 54 Iface *net.Interface 55 56 // Port If it is not 0, replace the port 5353 with this port number. 57 Port int 58 59 // GetMachineIP is a function to return the IP of the local machine 60 GetMachineIP GetMachineIP 61 // LocalhostChecking if enabled asks the server to also send responses to 0.0.0.0 if the target IP 62 // is this host (as defined by GetMachineIP). Useful in case machine is on a VPN which blocks comms on non standard ports 63 LocalhostChecking bool 64 } 65 66 // Server is an mDNS server used to listen for mDNS queries and respond if we 67 // have a matching local record 68 type Server struct { 69 config *Config 70 71 ipv4List *net.UDPConn 72 ipv6List *net.UDPConn 73 74 shutdown bool 75 shutdownCh chan struct{} 76 shutdownLock sync.Mutex 77 wg sync.WaitGroup 78 79 outboundIP net.IP 80 } 81 82 // NewServer is used to create a new mDNS server from a config 83 func NewServer(config *Config) (*Server, error) { 84 setCustomPort(config.Port) 85 86 // Create the listeners 87 // Create wildcard connections (because :5353 can be already taken by other apps) 88 ipv4List, _ := net.ListenUDP("udp4", mdnsWildcardAddrIPv4) 89 ipv6List, _ := net.ListenUDP("udp6", mdnsWildcardAddrIPv6) 90 if ipv4List == nil && ipv6List == nil { 91 return nil, fmt.Errorf("[ERR] mdns: Failed to bind to any udp port!") 92 } 93 94 if ipv4List == nil { 95 ipv4List = &net.UDPConn{} 96 } 97 if ipv6List == nil { 98 ipv6List = &net.UDPConn{} 99 } 100 101 // Join multicast groups to receive announcements 102 p1 := ipv4.NewPacketConn(ipv4List) 103 p2 := ipv6.NewPacketConn(ipv6List) 104 p1.SetMulticastLoopback(true) 105 p2.SetMulticastLoopback(true) 106 107 if config.Iface != nil { 108 if err := p1.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { 109 return nil, err 110 } 111 if err := p2.JoinGroup(config.Iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { 112 return nil, err 113 } 114 } else { 115 ifaces, err := net.Interfaces() 116 if err != nil { 117 return nil, err 118 } 119 errCount1, errCount2 := 0, 0 120 for _, iface := range ifaces { 121 if err := p1.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv4}); err != nil { 122 errCount1++ 123 } 124 if err := p2.JoinGroup(&iface, &net.UDPAddr{IP: mdnsGroupIPv6}); err != nil { 125 errCount2++ 126 } 127 } 128 if len(ifaces) == errCount1 && len(ifaces) == errCount2 { 129 return nil, fmt.Errorf("Failed to join multicast group on all interfaces!") 130 } 131 } 132 133 ipFunc := getOutboundIP 134 if config.GetMachineIP != nil { 135 ipFunc = config.GetMachineIP 136 } 137 138 s := &Server{ 139 config: config, 140 ipv4List: ipv4List, 141 ipv6List: ipv6List, 142 shutdownCh: make(chan struct{}), 143 outboundIP: ipFunc(), 144 } 145 146 go s.recv(s.ipv4List) 147 go s.recv(s.ipv6List) 148 149 s.wg.Add(1) 150 go s.probe() 151 152 return s, nil 153 } 154 155 // Shutdown is used to shutdown the listener 156 func (s *Server) Shutdown() error { 157 s.shutdownLock.Lock() 158 defer s.shutdownLock.Unlock() 159 160 if s.shutdown { 161 return nil 162 } 163 164 s.shutdown = true 165 close(s.shutdownCh) 166 s.unregister() 167 168 if s.ipv4List != nil { 169 s.ipv4List.Close() 170 } 171 if s.ipv6List != nil { 172 s.ipv6List.Close() 173 } 174 175 s.wg.Wait() 176 return nil 177 } 178 179 // recv is a long running routine to receive packets from an interface 180 func (s *Server) recv(c *net.UDPConn) { 181 if c == nil { 182 return 183 } 184 buf := make([]byte, 65536) 185 for { 186 s.shutdownLock.Lock() 187 if s.shutdown { 188 s.shutdownLock.Unlock() 189 return 190 } 191 s.shutdownLock.Unlock() 192 n, from, err := c.ReadFrom(buf) 193 if err != nil { 194 continue 195 } 196 if err := s.parsePacket(buf[:n], from); err != nil { 197 log.Errf("[ERR] mdns: Failed to handle query: %v", err) 198 } 199 } 200 } 201 202 // parsePacket is used to parse an incoming packet 203 func (s *Server) parsePacket(packet []byte, from net.Addr) error { 204 var msg dns.Msg 205 if err := msg.Unpack(packet); err != nil { 206 log.Errf("[ERR] mdns: Failed to unpack packet: %v", err) 207 return err 208 } 209 // TODO: This is a bit of a hack 210 // We decided to ignore some mDNS answers for the time being 211 // See: https://tools.ietf.org/html/rfc6762#section-7.2 212 msg.Truncated = false 213 return s.handleQuery(&msg, from) 214 } 215 216 // handleQuery is used to handle an incoming query 217 func (s *Server) handleQuery(query *dns.Msg, from net.Addr) error { 218 if query.Opcode != dns.OpcodeQuery { 219 // "In both multicast query and multicast response messages, the OPCODE MUST 220 // be zero on transmission (only standard queries are currently supported 221 // over multicast). Multicast DNS messages received with an OPCODE other 222 // than zero MUST be silently ignored." Note: OpcodeQuery == 0 223 return fmt.Errorf("mdns: received query with non-zero Opcode %v: %v", query.Opcode, *query) 224 } 225 if query.Rcode != 0 { 226 // "In both multicast query and multicast response messages, the Response 227 // Code MUST be zero on transmission. Multicast DNS messages received with 228 // non-zero Response Codes MUST be silently ignored." 229 return fmt.Errorf("mdns: received query with non-zero Rcode %v: %v", query.Rcode, *query) 230 } 231 232 // TODO(reddaly): Handle "TC (Truncated) Bit": 233 // In query messages, if the TC bit is set, it means that additional 234 // Known-Answer records may be following shortly. A responder SHOULD 235 // record this fact, and wait for those additional Known-Answer records, 236 // before deciding whether to respond. If the TC bit is clear, it means 237 // that the querying host has no additional Known Answers. 238 if query.Truncated { 239 return fmt.Errorf("[ERR] mdns: support for DNS requests with high truncated bit not implemented: %v", *query) 240 } 241 242 var unicastAnswer, multicastAnswer []dns.RR 243 244 // Handle each question 245 for _, q := range query.Question { 246 mrecs, urecs := s.handleQuestion(q) 247 multicastAnswer = append(multicastAnswer, mrecs...) 248 unicastAnswer = append(unicastAnswer, urecs...) 249 } 250 251 // See section 18 of RFC 6762 for rules about DNS headers. 252 resp := func(unicast bool) *dns.Msg { 253 // 18.1: ID (Query Identifier) 254 // 0 for multicast response, query.Id for unicast response 255 id := uint16(0) 256 if unicast { 257 id = query.Id 258 } 259 260 var answer []dns.RR 261 if unicast { 262 answer = unicastAnswer 263 } else { 264 answer = multicastAnswer 265 } 266 if len(answer) == 0 { 267 return nil 268 } 269 270 return &dns.Msg{ 271 MsgHdr: dns.MsgHdr{ 272 Id: id, 273 274 // 18.2: QR (Query/Response) Bit - must be set to 1 in response. 275 Response: true, 276 277 // 18.3: OPCODE - must be zero in response (OpcodeQuery == 0) 278 Opcode: dns.OpcodeQuery, 279 280 // 18.4: AA (Authoritative Answer) Bit - must be set to 1 281 Authoritative: true, 282 283 // The following fields must all be set to 0: 284 // 18.5: TC (TRUNCATED) Bit 285 // 18.6: RD (Recursion Desired) Bit 286 // 18.7: RA (Recursion Available) Bit 287 // 18.8: Z (Zero) Bit 288 // 18.9: AD (Authentic Data) Bit 289 // 18.10: CD (Checking Disabled) Bit 290 // 18.11: RCODE (Response Code) 291 }, 292 // 18.12 pertains to questions (handled by handleQuestion) 293 // 18.13 pertains to resource records (handled by handleQuestion) 294 295 // 18.14: Name Compression - responses should be compressed (though see 296 // caveats in the RFC), so set the Compress bit (part of the dns library 297 // API, not part of the DNS packet) to true. 298 Compress: true, 299 Question: query.Question, 300 Answer: answer, 301 } 302 } 303 304 if mresp := resp(false); mresp != nil { 305 if err := s.sendResponse(mresp, from); err != nil { 306 return fmt.Errorf("mdns: error sending multicast response: %v", err) 307 } 308 } 309 if uresp := resp(true); uresp != nil { 310 if err := s.sendResponse(uresp, from); err != nil { 311 return fmt.Errorf("mdns: error sending unicast response: %v", err) 312 } 313 } 314 return nil 315 } 316 317 // handleQuestion is used to handle an incoming question 318 // 319 // The response to a question may be transmitted over multicast, unicast, or 320 // both. The return values are DNS records for each transmission type. 321 func (s *Server) handleQuestion(q dns.Question) (multicastRecs, unicastRecs []dns.RR) { 322 records := s.config.Zone.Records(q) 323 if len(records) == 0 { 324 return nil, nil 325 } 326 327 // Handle unicast and multicast responses. 328 // TODO(reddaly): The decision about sending over unicast vs. multicast is not 329 // yet fully compliant with RFC 6762. For example, the unicast bit should be 330 // ignored if the records in question are close to TTL expiration. For now, 331 // we just use the unicast bit to make the decision, as per the spec: 332 // RFC 6762, section 18.12. Repurposing of Top Bit of qclass in Question 333 // Section 334 // 335 // In the Question Section of a Multicast DNS query, the top bit of the 336 // qclass field is used to indicate that unicast responses are preferred 337 // for this particular question. (See Section 5.4.) 338 if q.Qclass&(1<<15) != 0 { 339 return nil, records 340 } 341 return records, nil 342 } 343 344 func (s *Server) probe() { 345 defer s.wg.Done() 346 347 sd, ok := s.config.Zone.(*MDNSService) 348 if !ok { 349 return 350 } 351 352 name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain)) 353 354 q := new(dns.Msg) 355 q.SetQuestion(name, dns.TypePTR) 356 q.RecursionDesired = false 357 358 srv := &dns.SRV{ 359 Hdr: dns.RR_Header{ 360 Name: name, 361 Rrtype: dns.TypeSRV, 362 Class: dns.ClassINET, 363 Ttl: defaultTTL, 364 }, 365 Priority: 0, 366 Weight: 0, 367 Port: uint16(sd.Port), 368 Target: sd.HostName, 369 } 370 txt := &dns.TXT{ 371 Hdr: dns.RR_Header{ 372 Name: name, 373 Rrtype: dns.TypeTXT, 374 Class: dns.ClassINET, 375 Ttl: defaultTTL, 376 }, 377 Txt: sd.TXT, 378 } 379 q.Ns = []dns.RR{srv, txt} 380 381 randomizer := rand.New(rand.NewSource(time.Now().UnixNano())) 382 383 for i := 0; i < 3; i++ { 384 if err := s.SendMulticast(q); err != nil { 385 log.Errf("[ERR] mdns: failed to send probe:", err.Error()) 386 } 387 time.Sleep(time.Duration(randomizer.Intn(250)) * time.Millisecond) 388 } 389 390 resp := new(dns.Msg) 391 resp.MsgHdr.Response = true 392 393 // set for query 394 q.SetQuestion(name, dns.TypeANY) 395 396 resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...) 397 398 // reset 399 q.SetQuestion(name, dns.TypePTR) 400 401 // From RFC6762 402 // The Multicast DNS responder MUST send at least two unsolicited 403 // responses, one second apart. To provide increased robustness against 404 // packet loss, a responder MAY send up to eight unsolicited responses, 405 // provided that the interval between unsolicited responses increases by 406 // at least a factor of two with every response sent. 407 timeout := 1 * time.Second 408 timer := time.NewTimer(timeout) 409 for i := 0; i < 3; i++ { 410 if err := s.SendMulticast(resp); err != nil { 411 log.Errf("[ERR] mdns: failed to send announcement:", err.Error()) 412 } 413 select { 414 case <-timer.C: 415 timeout *= 2 416 timer.Reset(timeout) 417 case <-s.shutdownCh: 418 timer.Stop() 419 return 420 } 421 } 422 } 423 424 // SendMulticast us used to send a multicast response packet 425 func (s *Server) SendMulticast(msg *dns.Msg) error { 426 buf, err := msg.Pack() 427 if err != nil { 428 return err 429 } 430 if s.ipv4List != nil { 431 s.ipv4List.WriteToUDP(buf, ipv4Addr) 432 } 433 if s.ipv6List != nil { 434 s.ipv6List.WriteToUDP(buf, ipv6Addr) 435 } 436 return nil 437 } 438 439 // sendResponse is used to send a response packet 440 func (s *Server) sendResponse(resp *dns.Msg, from net.Addr) error { 441 // TODO(reddaly): Respect the unicast argument, and allow sending responses 442 // over multicast. 443 buf, err := resp.Pack() 444 if err != nil { 445 return err 446 } 447 448 // Determine the socket to send from 449 addr := from.(*net.UDPAddr) 450 conn := s.ipv4List 451 backupTarget := net.IPv4zero 452 453 if addr.IP.To4() == nil { 454 conn = s.ipv6List 455 backupTarget = net.IPv6zero 456 } 457 _, err = conn.WriteToUDP(buf, addr) 458 // If the address we're responding to is this machine then we can also attempt sending on 0.0.0.0 459 // This covers the case where this machine is using a VPN and certain ports are blocked so the response never gets there 460 // Sending two responses is OK 461 if s.config.LocalhostChecking && addr.IP.Equal(s.outboundIP) { 462 // ignore any errors, this is best efforts 463 conn.WriteToUDP(buf, &net.UDPAddr{IP: backupTarget, Port: addr.Port}) 464 } 465 return err 466 467 } 468 469 func (s *Server) unregister() error { 470 sd, ok := s.config.Zone.(*MDNSService) 471 if !ok { 472 return nil 473 } 474 475 atomic.StoreUint32(&sd.TTL, 0) 476 name := fmt.Sprintf("%s.%s.%s.", sd.Instance, trimDot(sd.Service), trimDot(sd.Domain)) 477 478 q := new(dns.Msg) 479 q.SetQuestion(name, dns.TypeANY) 480 481 resp := new(dns.Msg) 482 resp.MsgHdr.Response = true 483 resp.Answer = append(resp.Answer, s.config.Zone.Records(q.Question[0])...) 484 485 return s.SendMulticast(resp) 486 } 487 488 func setCustomPort(port int) { 489 if port != 0 { 490 if mdnsWildcardAddrIPv4.Port != port { 491 mdnsWildcardAddrIPv4.Port = port 492 } 493 if mdnsWildcardAddrIPv6.Port != port { 494 mdnsWildcardAddrIPv6.Port = port 495 } 496 if ipv4Addr.Port != port { 497 ipv4Addr.Port = port 498 } 499 if ipv6Addr.Port != port { 500 ipv6Addr.Port = port 501 } 502 } 503 } 504 505 // getOutboundIP returns the IP address of this machine as seen when dialling out 506 func getOutboundIP() net.IP { 507 conn, err := net.Dial("udp", "8.8.8.8:80") 508 if err != nil { 509 // no net connectivity maybe so fallback 510 return nil 511 } 512 defer conn.Close() 513 514 localAddr := conn.LocalAddr().(*net.UDPAddr) 515 516 return localAddr.IP 517 }