github.com/slackhq/nebula@v1.9.0/dns_server.go (about) 1 package nebula 2 3 import ( 4 "fmt" 5 "net" 6 "strconv" 7 "strings" 8 "sync" 9 10 "github.com/miekg/dns" 11 "github.com/sirupsen/logrus" 12 "github.com/slackhq/nebula/config" 13 "github.com/slackhq/nebula/iputil" 14 ) 15 16 // This whole thing should be rewritten to use context 17 18 var dnsR *dnsRecords 19 var dnsServer *dns.Server 20 var dnsAddr string 21 22 type dnsRecords struct { 23 sync.RWMutex 24 dnsMap map[string]string 25 hostMap *HostMap 26 } 27 28 func newDnsRecords(hostMap *HostMap) *dnsRecords { 29 return &dnsRecords{ 30 dnsMap: make(map[string]string), 31 hostMap: hostMap, 32 } 33 } 34 35 func (d *dnsRecords) Query(data string) string { 36 d.RLock() 37 defer d.RUnlock() 38 if r, ok := d.dnsMap[strings.ToLower(data)]; ok { 39 return r 40 } 41 return "" 42 } 43 44 func (d *dnsRecords) QueryCert(data string) string { 45 ip := net.ParseIP(data[:len(data)-1]) 46 if ip == nil { 47 return "" 48 } 49 iip := iputil.Ip2VpnIp(ip) 50 hostinfo := d.hostMap.QueryVpnIp(iip) 51 if hostinfo == nil { 52 return "" 53 } 54 q := hostinfo.GetCert() 55 if q == nil { 56 return "" 57 } 58 cert := q.Details 59 c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer) 60 return c 61 } 62 63 func (d *dnsRecords) Add(host, data string) { 64 d.Lock() 65 defer d.Unlock() 66 d.dnsMap[strings.ToLower(host)] = data 67 } 68 69 func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) { 70 for _, q := range m.Question { 71 switch q.Qtype { 72 case dns.TypeA: 73 l.Debugf("Query for A %s", q.Name) 74 ip := dnsR.Query(q.Name) 75 if ip != "" { 76 rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip)) 77 if err == nil { 78 m.Answer = append(m.Answer, rr) 79 } 80 } 81 case dns.TypeTXT: 82 a, _, _ := net.SplitHostPort(w.RemoteAddr().String()) 83 b := net.ParseIP(a) 84 // We don't answer these queries from non nebula nodes or localhost 85 //l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR) 86 if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" { 87 return 88 } 89 l.Debugf("Query for TXT %s", q.Name) 90 ip := dnsR.QueryCert(q.Name) 91 if ip != "" { 92 rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip)) 93 if err == nil { 94 m.Answer = append(m.Answer, rr) 95 } 96 } 97 } 98 } 99 100 if len(m.Answer) == 0 { 101 m.Rcode = dns.RcodeNameError 102 } 103 } 104 105 func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) { 106 m := new(dns.Msg) 107 m.SetReply(r) 108 m.Compress = false 109 110 switch r.Opcode { 111 case dns.OpcodeQuery: 112 parseQuery(l, m, w) 113 } 114 115 w.WriteMsg(m) 116 } 117 118 func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() { 119 dnsR = newDnsRecords(hostMap) 120 121 // attach request handler func 122 dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) { 123 handleDnsRequest(l, w, r) 124 }) 125 126 c.RegisterReloadCallback(func(c *config.C) { 127 reloadDns(l, c) 128 }) 129 130 return func() { 131 startDns(l, c) 132 } 133 } 134 135 func getDnsServerAddr(c *config.C) string { 136 dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", "")) 137 // Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve. 138 if dnsHost == "[::]" { 139 dnsHost = "::" 140 } 141 return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53))) 142 } 143 144 func startDns(l *logrus.Logger, c *config.C) { 145 dnsAddr = getDnsServerAddr(c) 146 dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"} 147 l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder") 148 err := dnsServer.ListenAndServe() 149 defer dnsServer.Shutdown() 150 if err != nil { 151 l.Errorf("Failed to start server: %s\n ", err.Error()) 152 } 153 } 154 155 func reloadDns(l *logrus.Logger, c *config.C) { 156 if dnsAddr == getDnsServerAddr(c) { 157 l.Debug("No DNS server config change detected") 158 return 159 } 160 161 l.Debug("Restarting DNS server") 162 dnsServer.Shutdown() 163 go startDns(l, c) 164 }