github.com/noisysockets/noisysockets@v0.21.2-0.20240515114641-7f467e651c90/internal/dns/udp_resolver.go (about)

     1  // SPDX-License-Identifier: MPL-2.0
     2  /*
     3   * Copyright (C) 2024 The Noisy Sockets Authors.
     4   *
     5   * This Source Code Form is subject to the terms of the Mozilla Public
     6   * License, v. 2.0. If a copy of the MPL was not distributed with this
     7   * file, You can obtain one at http://mozilla.org/MPL/2.0/.
     8   */
     9  
    10  package dns
    11  
    12  import (
    13  	"context"
    14  	"errors"
    15  	"fmt"
    16  	stdnet "net"
    17  	"net/netip"
    18  	"time"
    19  
    20  	"github.com/miekg/dns"
    21  	"github.com/noisysockets/noisysockets/internal/dns/addrselect"
    22  	"github.com/noisysockets/noisysockets/internal/util"
    23  	"github.com/noisysockets/noisysockets/network"
    24  	"github.com/noisysockets/noisysockets/networkutil"
    25  )
    26  
    27  // udpResolver is a DNS resolver that uses DNS over UDP.
    28  type udpResolver struct {
    29  	net         network.Network
    30  	nameservers []netip.AddrPort
    31  }
    32  
    33  // NewUDPResolver creates a new DNS resolver that uses DNS over UDP.
    34  func NewUDPResolver(net network.Network, nameservers []netip.AddrPort) Resolver {
    35  	// Use the default DNS port if none is specified.
    36  	for i, ns := range nameservers {
    37  		if ns.Port() == 0 {
    38  			nameservers[i] = netip.AddrPortFrom(ns.Addr(), 53)
    39  		}
    40  	}
    41  
    42  	return &udpResolver{
    43  		net:         net,
    44  		nameservers: nameservers,
    45  	}
    46  }
    47  
    48  func (r *udpResolver) LookupHost(host string) ([]netip.Addr, error) {
    49  	client := &dns.Client{
    50  		Net:     "udp",
    51  		Timeout: 10 * time.Second,
    52  	}
    53  
    54  	interfaceAddrs, err := r.net.InterfaceAddrs()
    55  	if err != nil {
    56  		return nil, fmt.Errorf("could not get interface addresses: %w", err)
    57  	}
    58  
    59  	netipAddrs, ok := networkutil.ToNetIPAddrs(interfaceAddrs)
    60  	if !ok {
    61  		return nil, fmt.Errorf("could not convert interface addresses")
    62  	}
    63  
    64  	var queryTypes []uint16
    65  	if networkutil.HasIPv4(netipAddrs) {
    66  		queryTypes = append(queryTypes, dns.TypeA)
    67  	}
    68  	if networkutil.HasIPv6(netipAddrs) {
    69  		queryTypes = append(queryTypes, dns.TypeAAAA)
    70  	}
    71  
    72  	// Shuffle the nameserver list for load balancing.
    73  	shuffledNameservers := make([]netip.AddrPort, len(r.nameservers))
    74  	copy(shuffledNameservers, r.nameservers)
    75  	shuffledNameservers = util.Shuffle(shuffledNameservers)
    76  
    77  	var addrs []netip.Addr
    78  	var queryErr error
    79  
    80  	for _, ns := range shuffledNameservers {
    81  		for _, queryType := range queryTypes {
    82  			in, err := r.queryNameserver(client, ns, queryType, host)
    83  			if err != nil {
    84  				queryErr = errors.Join(queryErr, err)
    85  				continue
    86  			}
    87  
    88  			for _, rr := range in.Answer {
    89  				switch rr := rr.(type) {
    90  				case *dns.A:
    91  					addrs = append(addrs, netip.AddrFrom4([4]byte(rr.A.To4())))
    92  				case *dns.AAAA:
    93  					addrs = append(addrs, netip.AddrFrom16([16]byte(rr.AAAA.To16())))
    94  				}
    95  			}
    96  		}
    97  
    98  		if len(addrs) > 0 {
    99  			addrselect.SortByRFC6724(r.net, addrs)
   100  			return addrs, nil
   101  		}
   102  	}
   103  
   104  	if queryErr != nil {
   105  		return nil, &stdnet.DNSError{Err: queryErr.Error(), Name: host}
   106  	}
   107  
   108  	return nil, &stdnet.DNSError{Err: "no such host", Name: host}
   109  }
   110  
   111  func (r *udpResolver) queryNameserver(client *dns.Client, nameserver netip.AddrPort, queryType uint16, host string) (*dns.Msg, error) {
   112  	ctx, cancel := context.WithTimeout(context.Background(), client.Timeout)
   113  	defer cancel()
   114  
   115  	conn, err := r.net.DialContext(ctx, client.Net, nameserver.String())
   116  	if err != nil {
   117  		return nil, &stdnet.DNSError{
   118  			Err:  fmt.Errorf("could not connect to DNS server %s: %w", nameserver, err).Error(),
   119  			Name: host,
   120  		}
   121  	}
   122  	defer conn.Close()
   123  
   124  	req := new(dns.Msg)
   125  	req.SetQuestion(dns.Fqdn(host), queryType)
   126  
   127  	reply, _, err := client.ExchangeWithConn(req, &dns.Conn{Conn: conn})
   128  	if err != nil {
   129  		return nil, &stdnet.DNSError{
   130  			Err:  fmt.Errorf("could not query DNS server %s: %w", nameserver.String(), err).Error(),
   131  			Name: host,
   132  		}
   133  	}
   134  
   135  	return reply, nil
   136  }