github.com/slackhq/nebula@v1.9.0/dns_server.go (about)

     1  package nebula
     2  
     3  import (
     4  	"fmt"
     5  	"net"
     6  	"strconv"
     7  	"strings"
     8  	"sync"
     9  
    10  	"github.com/miekg/dns"
    11  	"github.com/sirupsen/logrus"
    12  	"github.com/slackhq/nebula/config"
    13  	"github.com/slackhq/nebula/iputil"
    14  )
    15  
    16  // This whole thing should be rewritten to use context
    17  
    18  var dnsR *dnsRecords
    19  var dnsServer *dns.Server
    20  var dnsAddr string
    21  
// dnsRecords holds the records served for A queries plus a reference to the
// HostMap used to look up peer certificates for TXT queries. The embedded
// RWMutex guards dnsMap (Query takes the read lock, Add the write lock).
type dnsRecords struct {
	sync.RWMutex
	dnsMap  map[string]string // lower-cased host name -> data used as the A record value
	hostMap *HostMap          // consulted by QueryCert to find a host by VPN IP
}
    27  
    28  func newDnsRecords(hostMap *HostMap) *dnsRecords {
    29  	return &dnsRecords{
    30  		dnsMap:  make(map[string]string),
    31  		hostMap: hostMap,
    32  	}
    33  }
    34  
    35  func (d *dnsRecords) Query(data string) string {
    36  	d.RLock()
    37  	defer d.RUnlock()
    38  	if r, ok := d.dnsMap[strings.ToLower(data)]; ok {
    39  		return r
    40  	}
    41  	return ""
    42  }
    43  
    44  func (d *dnsRecords) QueryCert(data string) string {
    45  	ip := net.ParseIP(data[:len(data)-1])
    46  	if ip == nil {
    47  		return ""
    48  	}
    49  	iip := iputil.Ip2VpnIp(ip)
    50  	hostinfo := d.hostMap.QueryVpnIp(iip)
    51  	if hostinfo == nil {
    52  		return ""
    53  	}
    54  	q := hostinfo.GetCert()
    55  	if q == nil {
    56  		return ""
    57  	}
    58  	cert := q.Details
    59  	c := fmt.Sprintf("\"Name: %s\" \"Ips: %s\" \"Subnets %s\" \"Groups %s\" \"NotBefore %s\" \"NotAfter %s\" \"PublicKey %x\" \"IsCA %t\" \"Issuer %s\"", cert.Name, cert.Ips, cert.Subnets, cert.Groups, cert.NotBefore, cert.NotAfter, cert.PublicKey, cert.IsCA, cert.Issuer)
    60  	return c
    61  }
    62  
    63  func (d *dnsRecords) Add(host, data string) {
    64  	d.Lock()
    65  	defer d.Unlock()
    66  	d.dnsMap[strings.ToLower(host)] = data
    67  }
    68  
    69  func parseQuery(l *logrus.Logger, m *dns.Msg, w dns.ResponseWriter) {
    70  	for _, q := range m.Question {
    71  		switch q.Qtype {
    72  		case dns.TypeA:
    73  			l.Debugf("Query for A %s", q.Name)
    74  			ip := dnsR.Query(q.Name)
    75  			if ip != "" {
    76  				rr, err := dns.NewRR(fmt.Sprintf("%s A %s", q.Name, ip))
    77  				if err == nil {
    78  					m.Answer = append(m.Answer, rr)
    79  				}
    80  			}
    81  		case dns.TypeTXT:
    82  			a, _, _ := net.SplitHostPort(w.RemoteAddr().String())
    83  			b := net.ParseIP(a)
    84  			// We don't answer these queries from non nebula nodes or localhost
    85  			//l.Debugf("Does %s contain %s", b, dnsR.hostMap.vpnCIDR)
    86  			if !dnsR.hostMap.vpnCIDR.Contains(b) && a != "127.0.0.1" {
    87  				return
    88  			}
    89  			l.Debugf("Query for TXT %s", q.Name)
    90  			ip := dnsR.QueryCert(q.Name)
    91  			if ip != "" {
    92  				rr, err := dns.NewRR(fmt.Sprintf("%s TXT %s", q.Name, ip))
    93  				if err == nil {
    94  					m.Answer = append(m.Answer, rr)
    95  				}
    96  			}
    97  		}
    98  	}
    99  
   100  	if len(m.Answer) == 0 {
   101  		m.Rcode = dns.RcodeNameError
   102  	}
   103  }
   104  
   105  func handleDnsRequest(l *logrus.Logger, w dns.ResponseWriter, r *dns.Msg) {
   106  	m := new(dns.Msg)
   107  	m.SetReply(r)
   108  	m.Compress = false
   109  
   110  	switch r.Opcode {
   111  	case dns.OpcodeQuery:
   112  		parseQuery(l, m, w)
   113  	}
   114  
   115  	w.WriteMsg(m)
   116  }
   117  
   118  func dnsMain(l *logrus.Logger, hostMap *HostMap, c *config.C) func() {
   119  	dnsR = newDnsRecords(hostMap)
   120  
   121  	// attach request handler func
   122  	dns.HandleFunc(".", func(w dns.ResponseWriter, r *dns.Msg) {
   123  		handleDnsRequest(l, w, r)
   124  	})
   125  
   126  	c.RegisterReloadCallback(func(c *config.C) {
   127  		reloadDns(l, c)
   128  	})
   129  
   130  	return func() {
   131  		startDns(l, c)
   132  	}
   133  }
   134  
   135  func getDnsServerAddr(c *config.C) string {
   136  	dnsHost := strings.TrimSpace(c.GetString("lighthouse.dns.host", ""))
   137  	// Old guidance was to provide the literal `[::]` in `lighthouse.dns.host` but that won't resolve.
   138  	if dnsHost == "[::]" {
   139  		dnsHost = "::"
   140  	}
   141  	return net.JoinHostPort(dnsHost, strconv.Itoa(c.GetInt("lighthouse.dns.port", 53)))
   142  }
   143  
   144  func startDns(l *logrus.Logger, c *config.C) {
   145  	dnsAddr = getDnsServerAddr(c)
   146  	dnsServer = &dns.Server{Addr: dnsAddr, Net: "udp"}
   147  	l.WithField("dnsListener", dnsAddr).Info("Starting DNS responder")
   148  	err := dnsServer.ListenAndServe()
   149  	defer dnsServer.Shutdown()
   150  	if err != nil {
   151  		l.Errorf("Failed to start server: %s\n ", err.Error())
   152  	}
   153  }
   154  
   155  func reloadDns(l *logrus.Logger, c *config.C) {
   156  	if dnsAddr == getDnsServerAddr(c) {
   157  		l.Debug("No DNS server config change detected")
   158  		return
   159  	}
   160  
   161  	l.Debug("Restarting DNS server")
   162  	dnsServer.Shutdown()
   163  	go startDns(l, c)
   164  }