github.com/database64128/shadowsocks-go@v1.10.2-0.20240315062903-143a773533f1/dns/dns.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"errors"
     7  	"fmt"
     8  	"io"
     9  	"net"
    10  	"net/netip"
    11  	"os"
    12  	"slices"
    13  	"sync"
    14  	"time"
    15  
    16  	"github.com/database64128/shadowsocks-go/conn"
    17  	"github.com/database64128/shadowsocks-go/zerocopy"
    18  	"go.uber.org/zap"
    19  	"golang.org/x/net/dns/dnsmessage"
    20  )
    21  
    22  const (
    23  	// maxDNSPacketSize is the maximum packet size to advertise in EDNS(0).
    24  	// We use the same value as Go itself.
    25  	maxDNSPacketSize = 1232
    26  
    27  	lookupTimeout = 20 * time.Second
    28  )
    29  
    30  var (
    31  	ErrLookup                = errors.New("name lookup failed")
    32  	ErrMessageNotResponse    = errors.New("message is not a response")
    33  	ErrDomainNoAssociatedIPs = errors.New("domain name has no associated IP addresses")
    34  )
    35  
    36  // ResolverConfig configures a DNS resolver.
    37  type ResolverConfig struct {
    38  	// Name is the resolver's name.
    39  	// The name must be unique among all resolvers.
    40  	Name string `json:"name"`
    41  
    42  	// Type is the resolver type.
    43  	//
    44  	// Available values:
    45  	// - "plain": Resolve names by sending cleartext DNS queries to the configured upstream server.
    46  	// - "system": Use the system resolver. This does not support custom server addresses or clients.
    47  	//
    48  	// The default value is "plain".
    49  	Type string `json:"type"`
    50  
    51  	// AddrPort is the upstream server's address and port.
    52  	AddrPort netip.AddrPort `json:"addrPort"`
    53  
    54  	// TCPClientName is the name of the TCPClient to use.
    55  	// Leave empty to disable TCP.
    56  	TCPClientName string `json:"tcpClientName"`
    57  
    58  	// UDPClientName is the name of the UDPClient to use.
    59  	// Leave empty to disable UDP.
    60  	UDPClientName string `json:"udpClientName"`
    61  }
    62  
    63  // SimpleResolver creates a new [SimpleResolver] from the config.
    64  func (rc *ResolverConfig) SimpleResolver(tcpClientMap map[string]zerocopy.TCPClient, udpClientMap map[string]zerocopy.UDPClient, logger *zap.Logger) (SimpleResolver, error) {
    65  	switch rc.Type {
    66  	case "plain", "":
    67  	case "system":
    68  		if rc.AddrPort.IsValid() || rc.TCPClientName != "" || rc.UDPClientName != "" {
    69  			return nil, errors.New("system resolver does not support custom server addresses or clients")
    70  		}
    71  		return NewSystemResolver(rc.Name, logger), nil
    72  	default:
    73  		return nil, fmt.Errorf("unknown resolver type: %s", rc.Type)
    74  	}
    75  
    76  	if !rc.AddrPort.IsValid() {
    77  		return nil, errors.New("missing resolver address")
    78  	}
    79  
    80  	var (
    81  		tcpClient zerocopy.TCPClient
    82  		udpClient zerocopy.UDPClient
    83  	)
    84  
    85  	if rc.TCPClientName != "" {
    86  		tcpClient = tcpClientMap[rc.TCPClientName]
    87  		if tcpClient == nil {
    88  			return nil, fmt.Errorf("unknown TCP client: %s", rc.TCPClientName)
    89  		}
    90  	}
    91  
    92  	if rc.UDPClientName != "" {
    93  		udpClient = udpClientMap[rc.UDPClientName]
    94  		if udpClient == nil {
    95  			return nil, fmt.Errorf("unknown UDP client: %s", rc.UDPClientName)
    96  		}
    97  	}
    98  
    99  	return NewResolver(rc.Name, rc.AddrPort, tcpClient, udpClient, logger), nil
   100  }
   101  
   102  // Result represents the result of name resolution.
   103  type Result struct {
   104  	IPv4 []netip.Addr
   105  	IPv6 []netip.Addr
   106  
   107  	// TTL is the minimum TTL of A and AAAA RRs.
   108  	TTL time.Time
   109  
   110  	v4done bool
   111  	v6done bool
   112  }
   113  
   114  type Resolver struct {
   115  	// name stores the resolver's name to make its log messages more useful.
   116  	name string
   117  
   118  	// mu protects the DNS cache map.
   119  	mu sync.RWMutex
   120  
   121  	// cache is the DNS cache map.
   122  	cache map[string]Result
   123  
   124  	// serverAddr is the upstream server's address and port.
   125  	serverAddr conn.Addr
   126  
   127  	// serverAddrPort is the upstream server's address and port.
   128  	serverAddrPort netip.AddrPort
   129  
   130  	// tcpClient is the TCPClient to use for sending queries and receiving replies.
   131  	tcpClient zerocopy.TCPClient
   132  
   133  	// udpClient is the UDPClient to use for sending queries and receiving replies.
   134  	udpClient zerocopy.UDPClient
   135  
   136  	// logger is the shared logger instance.
   137  	logger *zap.Logger
   138  }
   139  
   140  func NewResolver(name string, serverAddrPort netip.AddrPort, tcpClient zerocopy.TCPClient, udpClient zerocopy.UDPClient, logger *zap.Logger) *Resolver {
   141  	return &Resolver{
   142  		name:           name,
   143  		cache:          make(map[string]Result),
   144  		serverAddr:     conn.AddrFromIPPort(serverAddrPort),
   145  		serverAddrPort: serverAddrPort,
   146  		tcpClient:      tcpClient,
   147  		udpClient:      udpClient,
   148  		logger:         logger,
   149  	}
   150  }
   151  
   152  func (r *Resolver) lookup(ctx context.Context, name string) (Result, error) {
   153  	// Lookup cache first.
   154  	r.mu.RLock()
   155  	result, ok := r.cache[name]
   156  	r.mu.RUnlock()
   157  
   158  	if ok && !result.HasExpired() {
   159  		if ce := r.logger.Check(zap.DebugLevel, "DNS lookup got result from cache"); ce != nil {
   160  			ce.Write(
   161  				zap.String("resolver", r.name),
   162  				zap.String("name", name),
   163  				zap.Time("ttl", result.TTL),
   164  				zap.Stringers("v4", result.IPv4),
   165  				zap.Stringers("v6", result.IPv6),
   166  			)
   167  		}
   168  		return result, nil
   169  	}
   170  
   171  	// Send queries to upstream server.
   172  	return r.sendQueries(ctx, name)
   173  }
   174  
   175  func (r *Resolver) sendQueries(ctx context.Context, nameString string) (result Result, err error) {
   176  	name, err := dnsmessage.NewName(nameString + ".")
   177  	if err != nil {
   178  		return
   179  	}
   180  
   181  	var (
   182  		rh dnsmessage.ResourceHeader
   183  		rb dnsmessage.OPTResource
   184  	)
   185  
   186  	err = rh.SetEDNS0(maxDNSPacketSize, dnsmessage.RCodeSuccess, false)
   187  	if err != nil {
   188  		return
   189  	}
   190  
   191  	qBuf := make([]byte, 2+512+2+512)
   192  
   193  	q4 := dnsmessage.Message{
   194  		Header: dnsmessage.Header{
   195  			ID:               4,
   196  			RecursionDesired: true,
   197  		},
   198  		Questions: []dnsmessage.Question{
   199  			{
   200  				Name:  name,
   201  				Type:  dnsmessage.TypeA,
   202  				Class: dnsmessage.ClassINET,
   203  			},
   204  		},
   205  		Additionals: []dnsmessage.Resource{
   206  			{
   207  				Header: rh,
   208  				Body:   &rb,
   209  			},
   210  		},
   211  	}
   212  	q4Pkt := qBuf[2:2]
   213  	q4Pkt, err = q4.AppendPack(q4Pkt)
   214  	if err != nil {
   215  		return
   216  	}
   217  	q4PktEnd := 2 + len(q4Pkt)
   218  
   219  	q6 := dnsmessage.Message{
   220  		Header: dnsmessage.Header{
   221  			ID:               6,
   222  			RecursionDesired: true,
   223  		},
   224  		Questions: []dnsmessage.Question{
   225  			{
   226  				Name:  name,
   227  				Type:  dnsmessage.TypeAAAA,
   228  				Class: dnsmessage.ClassINET,
   229  			},
   230  		},
   231  		Additionals: []dnsmessage.Resource{
   232  			{
   233  				Header: rh,
   234  				Body:   &rb,
   235  			},
   236  		},
   237  	}
   238  	q6PktStart := q4PktEnd + 2
   239  	q6Pkt := qBuf[q6PktStart:q6PktStart]
   240  	q6Pkt, err = q6.AppendPack(q6Pkt)
   241  	if err != nil {
   242  		return
   243  	}
   244  	q6PktEnd := q6PktStart + len(q6Pkt)
   245  
   246  	// Try UDP first if available.
   247  	if r.udpClient != nil {
   248  		result = r.sendQueriesUDP(ctx, nameString, q4Pkt, q6Pkt)
   249  
   250  		if ce := r.logger.Check(zap.DebugLevel, "DNS lookup sent queries via UDP"); ce != nil {
   251  			ce.Write(
   252  				zap.String("resolver", r.name),
   253  				zap.String("name", nameString),
   254  				zap.Bool("handled", result.isDone()),
   255  				zap.Stringers("v4", result.IPv4),
   256  				zap.Stringers("v6", result.IPv6),
   257  				zap.Time("ttl", result.TTL),
   258  			)
   259  		}
   260  	}
   261  
   262  	// Fallback to TCP if UDP failed or is unavailable.
   263  	if !result.isDone() && r.tcpClient != nil {
   264  		// Write length fields.
   265  		q4LenBuf := qBuf[:2]
   266  		q6LenBuf := qBuf[q4PktEnd:q6PktStart]
   267  		binary.BigEndian.PutUint16(q4LenBuf, uint16(len(q4Pkt)))
   268  		binary.BigEndian.PutUint16(q6LenBuf, uint16(len(q6Pkt)))
   269  
   270  		result = r.sendQueriesTCP(ctx, nameString, qBuf[:q6PktEnd])
   271  
   272  		if ce := r.logger.Check(zap.DebugLevel, "DNS lookup sent queries via TCP"); ce != nil {
   273  			ce.Write(
   274  				zap.String("resolver", r.name),
   275  				zap.String("name", nameString),
   276  				zap.Bool("handled", result.isDone()),
   277  				zap.Stringers("v4", result.IPv4),
   278  				zap.Stringers("v6", result.IPv6),
   279  				zap.Time("ttl", result.TTL),
   280  			)
   281  		}
   282  	}
   283  
   284  	if !result.isDone() {
   285  		err = ErrLookup
   286  		return
   287  	}
   288  
   289  	// Add result to cache if TTL hasn't expired.
   290  	if !result.HasExpired() {
   291  		r.mu.Lock()
   292  		r.cache[nameString] = result
   293  		r.mu.Unlock()
   294  	}
   295  
   296  	return
   297  }
   298  
   299  // sendQueriesUDP sends queries using the resolver's UDP client and returns the result and whether the lookup was successful.
   300  func (r *Resolver) sendQueriesUDP(ctx context.Context, nameString string, q4Pkt, q6Pkt []byte) (result Result) {
   301  	ctx, cancel := context.WithTimeout(ctx, lookupTimeout)
   302  	defer cancel()
   303  
   304  	clientInfo, clientSession, err := r.udpClient.NewSession(ctx)
   305  	if err != nil {
   306  		r.logger.Warn("Failed to create new UDP client session",
   307  			zap.String("resolver", r.name),
   308  			zap.Error(err),
   309  		)
   310  		return
   311  	}
   312  	defer clientSession.Close()
   313  
   314  	udpConn, err := clientInfo.ListenConfig.ListenUDP(ctx, "udp", "")
   315  	if err != nil {
   316  		r.logger.Warn("Failed to create UDP socket for DNS lookup",
   317  			zap.String("resolver", r.name),
   318  			zap.Error(err),
   319  		)
   320  		return
   321  	}
   322  	defer udpConn.Close()
   323  
   324  	go func() {
   325  		<-ctx.Done()
   326  		udpConn.SetReadDeadline(conn.ALongTimeAgo)
   327  	}()
   328  
   329  	// Spin up senders.
   330  	// Each sender will keep sending at 2s intervals until
   331  	// done unblocks or after 10 iterations.
   332  	sendFunc := func(pkt []byte, done <-chan struct{}) {
   333  		b := make([]byte, clientInfo.PackerHeadroom.Front+len(pkt)+clientInfo.PackerHeadroom.Rear)
   334  
   335  		for range 10 {
   336  			copy(b[clientInfo.PackerHeadroom.Front:], pkt)
   337  			destAddrPort, packetStart, packetLength, err := clientSession.Packer.PackInPlace(ctx, b, r.serverAddr, clientInfo.PackerHeadroom.Front, len(pkt))
   338  			if err != nil {
   339  				r.logger.Warn("Failed to pack UDP DNS query packet",
   340  					zap.String("resolver", r.name),
   341  					zap.String("name", nameString),
   342  					zap.Stringer("serverAddrPort", r.serverAddrPort),
   343  					zap.Error(err),
   344  				)
   345  				cancel()
   346  				return
   347  			}
   348  
   349  			_, err = udpConn.WriteToUDPAddrPort(b[packetStart:packetStart+packetLength], destAddrPort)
   350  			if err != nil {
   351  				r.logger.Warn("Failed to write UDP DNS query packet",
   352  					zap.String("resolver", r.name),
   353  					zap.String("name", nameString),
   354  					zap.Stringer("serverAddrPort", r.serverAddrPort),
   355  					zap.Stringer("destAddrPort", destAddrPort),
   356  					zap.Error(err),
   357  				)
   358  				cancel()
   359  				return
   360  			}
   361  
   362  			select {
   363  			case <-done:
   364  				return
   365  			case <-time.After(2 * time.Second):
   366  			}
   367  		}
   368  	}
   369  
   370  	ctx4, cancel4 := context.WithCancel(ctx)
   371  	ctx6, cancel6 := context.WithCancel(ctx)
   372  	defer cancel4()
   373  	defer cancel6()
   374  	go sendFunc(q4Pkt, ctx4.Done())
   375  	go sendFunc(q6Pkt, ctx6.Done())
   376  
   377  	// Receive replies.
   378  	recvBuf := make([]byte, clientSession.MaxPacketSize)
   379  
   380  	for {
   381  		n, _, flags, packetSourceAddress, err := udpConn.ReadMsgUDPAddrPort(recvBuf, nil)
   382  		if err != nil {
   383  			if errors.Is(err, os.ErrDeadlineExceeded) {
   384  				r.logger.Warn("DNS lookup via UDP timed out",
   385  					zap.String("resolver", r.name),
   386  					zap.String("name", nameString),
   387  					zap.Stringer("serverAddrPort", r.serverAddrPort),
   388  				)
   389  				break
   390  			}
   391  			r.logger.Warn("Failed to read UDP DNS response",
   392  				zap.String("resolver", r.name),
   393  				zap.String("name", nameString),
   394  				zap.Stringer("serverAddrPort", r.serverAddrPort),
   395  				zap.Stringer("packetSourceAddress", packetSourceAddress),
   396  				zap.Int("packetLength", n),
   397  				zap.Error(err),
   398  			)
   399  			continue
   400  		}
   401  		if err = conn.ParseFlagsForError(flags); err != nil {
   402  			r.logger.Warn("Failed to read UDP DNS response",
   403  				zap.String("resolver", r.name),
   404  				zap.String("name", nameString),
   405  				zap.Stringer("serverAddrPort", r.serverAddrPort),
   406  				zap.Stringer("packetSourceAddress", packetSourceAddress),
   407  				zap.Int("packetLength", n),
   408  				zap.Error(err),
   409  			)
   410  			continue
   411  		}
   412  
   413  		payloadSourceAddrPort, payloadStart, payloadLength, err := clientSession.Unpacker.UnpackInPlace(recvBuf, packetSourceAddress, 0, n)
   414  		if err != nil {
   415  			r.logger.Warn("Failed to unpack UDP DNS response packet",
   416  				zap.String("resolver", r.name),
   417  				zap.String("name", nameString),
   418  				zap.Stringer("serverAddrPort", r.serverAddrPort),
   419  				zap.Stringer("packetSourceAddress", packetSourceAddress),
   420  				zap.Int("packetLength", n),
   421  				zap.Error(err),
   422  			)
   423  			continue
   424  		}
   425  		if !conn.AddrPortMappedEqual(payloadSourceAddrPort, r.serverAddrPort) {
   426  			r.logger.Warn("Ignoring UDP DNS response packet from unknown server",
   427  				zap.String("resolver", r.name),
   428  				zap.String("name", nameString),
   429  				zap.Stringer("serverAddrPort", r.serverAddrPort),
   430  				zap.Stringer("payloadSourceAddrPort", payloadSourceAddrPort),
   431  			)
   432  			continue
   433  		}
   434  		msg := recvBuf[payloadStart : payloadStart+payloadLength]
   435  
   436  		header, err := result.parseMsg(msg)
   437  		if err != nil {
   438  			r.logger.Warn("Failed to parse UDP DNS response",
   439  				zap.String("resolver", r.name),
   440  				zap.String("name", nameString),
   441  				zap.Stringer("serverAddrPort", r.serverAddrPort),
   442  				zap.Error(err),
   443  			)
   444  			break
   445  		}
   446  		if header.Truncated {
   447  			if ce := r.logger.Check(zap.DebugLevel, "Received truncated UDP DNS response"); ce != nil {
   448  				ce.Write(
   449  					zap.String("resolver", r.name),
   450  					zap.String("name", nameString),
   451  					zap.Stringer("serverAddrPort", r.serverAddrPort),
   452  					zap.Uint16("transactionID", header.ID),
   453  				)
   454  			}
   455  			// Immediately fall back to TCP.
   456  			result.clearDone()
   457  			break
   458  		}
   459  
   460  		// Break out of loop if both v4 and v6 are done.
   461  		if result.isDone() {
   462  			break
   463  		}
   464  
   465  		switch header.ID {
   466  		case 4:
   467  			cancel4()
   468  		case 6:
   469  			cancel6()
   470  		}
   471  	}
   472  
   473  	return
   474  }
   475  
   476  // sendQueriesTCP sends queries using the resolver's TCP client and returns the result and whether the lookup was successful.
   477  func (r *Resolver) sendQueriesTCP(ctx context.Context, nameString string, queries []byte) (result Result) {
   478  	ctx, cancel := context.WithTimeout(ctx, lookupTimeout)
   479  	defer cancel()
   480  
   481  	// Write.
   482  	rawRW, rw, err := r.tcpClient.Dial(ctx, r.serverAddr, queries)
   483  	if err != nil {
   484  		r.logger.Warn("Failed to dial TCP DNS server",
   485  			zap.String("resolver", r.name),
   486  			zap.String("name", nameString),
   487  			zap.Stringer("serverAddrPort", r.serverAddrPort),
   488  			zap.Error(err),
   489  		)
   490  		return
   491  	}
   492  	defer rawRW.Close()
   493  
   494  	// Set read deadline.
   495  	if tc, ok := rawRW.(*net.TCPConn); ok {
   496  		go func() {
   497  			<-ctx.Done()
   498  			tc.SetReadDeadline(conn.ALongTimeAgo)
   499  		}()
   500  	}
   501  
   502  	// Read.
   503  	crw := zerocopy.NewCopyReadWriter(rw)
   504  	lengthBuf := make([]byte, 2)
   505  
   506  	for range 2 {
   507  		// Read length field.
   508  		_, err = io.ReadFull(crw, lengthBuf)
   509  		if err != nil {
   510  			r.logger.Warn("Failed to read TCP DNS response length",
   511  				zap.String("resolver", r.name),
   512  				zap.String("name", nameString),
   513  				zap.Stringer("serverAddrPort", r.serverAddrPort),
   514  				zap.Error(err),
   515  			)
   516  			return
   517  		}
   518  
   519  		msgLen := binary.BigEndian.Uint16(lengthBuf)
   520  		if msgLen == 0 {
   521  			r.logger.Warn("TCP DNS response length is zero",
   522  				zap.String("resolver", r.name),
   523  				zap.String("name", nameString),
   524  				zap.Stringer("serverAddrPort", r.serverAddrPort),
   525  			)
   526  			return
   527  		}
   528  
   529  		// Read message.
   530  		msg := make([]byte, msgLen)
   531  		_, err = io.ReadFull(crw, msg)
   532  		if err != nil {
   533  			r.logger.Warn("Failed to read TCP DNS response",
   534  				zap.String("resolver", r.name),
   535  				zap.String("name", nameString),
   536  				zap.Stringer("serverAddrPort", r.serverAddrPort),
   537  				zap.Error(err),
   538  			)
   539  			return
   540  		}
   541  
   542  		header, err := result.parseMsg(msg)
   543  		if err != nil {
   544  			r.logger.Warn("Failed to parse TCP DNS response",
   545  				zap.String("resolver", r.name),
   546  				zap.String("name", nameString),
   547  				zap.Stringer("serverAddrPort", r.serverAddrPort),
   548  				zap.Error(err),
   549  			)
   550  			return
   551  		}
   552  		if header.Truncated {
   553  			if ce := r.logger.Check(zap.DebugLevel, "Received truncated TCP DNS response"); ce != nil {
   554  				ce.Write(
   555  					zap.String("resolver", r.name),
   556  					zap.String("name", nameString),
   557  					zap.Stringer("serverAddrPort", r.serverAddrPort),
   558  					zap.Uint16("transactionID", header.ID),
   559  				)
   560  			}
   561  			// TCP DNS responses exceeding 65535 bytes are truncated.
   562  			// Use the truncated response like how Go std & the glibc resolver do.
   563  		}
   564  	}
   565  
   566  	return
   567  }
   568  
   569  func (r *Result) parseMsg(msg []byte) (dnsmessage.Header, error) {
   570  	var parser dnsmessage.Parser
   571  
   572  	// Parse header.
   573  	header, err := parser.Start(msg)
   574  	if err != nil {
   575  		return dnsmessage.Header{}, fmt.Errorf("failed to parse query response header: %w", err)
   576  	}
   577  
   578  	// Check transaction ID.
   579  	switch header.ID {
   580  	case 4:
   581  		if r.v4done {
   582  			return header, nil
   583  		}
   584  		r.IPv4 = r.IPv4[:0]
   585  	case 6:
   586  		if r.v6done {
   587  			return header, nil
   588  		}
   589  		r.IPv6 = r.IPv6[:0]
   590  	default:
   591  		return dnsmessage.Header{}, fmt.Errorf("unexpected transaction ID: %d", header.ID)
   592  	}
   593  
   594  	// Check response bit.
   595  	if !header.Response {
   596  		return dnsmessage.Header{}, ErrMessageNotResponse
   597  	}
   598  
   599  	// Continue parsing even if truncated.
   600  	// The caller may still want to use the result.
   601  
   602  	// Check RCode.
   603  	if header.RCode != dnsmessage.RCodeSuccess {
   604  		return dnsmessage.Header{}, fmt.Errorf("DNS failure: %s", header.RCode)
   605  	}
   606  
   607  	// Skip questions.
   608  	if err = parser.SkipAllQuestions(); err != nil {
   609  		return dnsmessage.Header{}, fmt.Errorf("failed to skip questions: %w", err)
   610  	}
   611  
   612  	// Parse answers and add to result.
   613  	for {
   614  		answerHeader, err := parser.AnswerHeader()
   615  		if err != nil {
   616  			if err == dnsmessage.ErrSectionDone {
   617  				break
   618  			}
   619  			return dnsmessage.Header{}, fmt.Errorf("failed to parse answer header: %w", err)
   620  		}
   621  
   622  		// Set minimum TTL.
   623  		ttl := time.Now().Add(time.Duration(answerHeader.TTL) * time.Second)
   624  		if r.TTL.IsZero() || r.TTL.After(ttl) {
   625  			r.TTL = ttl
   626  		}
   627  
   628  		// Skip non-A/AAAA RRs.
   629  		switch answerHeader.Type {
   630  		case dnsmessage.TypeA:
   631  			arr, err := parser.AResource()
   632  			if err != nil {
   633  				return dnsmessage.Header{}, fmt.Errorf("failed to parse A resource: %w", err)
   634  			}
   635  			r.IPv4 = append(r.IPv4, netip.AddrFrom4(arr.A))
   636  
   637  		case dnsmessage.TypeAAAA:
   638  			aaaarr, err := parser.AAAAResource()
   639  			if err != nil {
   640  				return dnsmessage.Header{}, fmt.Errorf("failed to parse AAAA resource: %w", err)
   641  			}
   642  			r.IPv6 = append(r.IPv6, netip.AddrFrom16(aaaarr.AAAA))
   643  
   644  		default:
   645  			if err = parser.SkipAnswer(); err != nil {
   646  				return dnsmessage.Header{}, fmt.Errorf("failed to skip answer: %w", err)
   647  			}
   648  		}
   649  	}
   650  
   651  	// Mark v4 or v6 as done.
   652  	switch header.ID {
   653  	case 4:
   654  		r.v4done = true
   655  	case 6:
   656  		r.v6done = true
   657  	}
   658  
   659  	return header, nil
   660  }
   661  
   662  func (r *Result) isDone() bool {
   663  	return r.v4done && r.v6done
   664  }
   665  
   666  func (r *Result) clearDone() {
   667  	r.v4done = false
   668  	r.v6done = false
   669  }
   670  
   671  // HasExpired returns true if the result's TTL has expired.
   672  func (r *Result) HasExpired() bool {
   673  	return r.TTL.Before(time.Now())
   674  }
   675  
   676  // Clone returns a deep copy of the result.
   677  // Modifying values in the address slices will not affect the original result.
   678  func (r *Result) Clone() Result {
   679  	return Result{
   680  		IPv4:   slices.Clone(r.IPv4),
   681  		IPv6:   slices.Clone(r.IPv6),
   682  		TTL:    r.TTL,
   683  		v4done: r.v4done,
   684  		v6done: r.v6done,
   685  	}
   686  }
   687  
   688  // Lookup looks up [name] and returns the result.
   689  func (r *Resolver) Lookup(ctx context.Context, name string) (Result, error) {
   690  	result, err := r.lookup(ctx, name)
   691  	if err != nil {
   692  		return Result{}, err
   693  	}
   694  	return result.Clone(), nil
   695  }
   696  
   697  // SimpleResolver defines methods that only return the resolved IP addresses.
   698  type SimpleResolver interface {
   699  	// LookupIP looks up [name] and returns one of the associated IP addresses.
   700  	LookupIP(ctx context.Context, name string) (netip.Addr, error)
   701  
   702  	// LookupIPs looks up [name] and returns all associated IP addresses.
   703  	LookupIPs(ctx context.Context, name string) ([]netip.Addr, error)
   704  }
   705  
   706  // LookupIP implements [SimpleResolver.LookupIP].
   707  func (r *Resolver) LookupIP(ctx context.Context, name string) (netip.Addr, error) {
   708  	result, err := r.lookup(ctx, name)
   709  	if err != nil {
   710  		return netip.Addr{}, err
   711  	}
   712  	if len(result.IPv6) > 0 {
   713  		return result.IPv6[0], nil
   714  	}
   715  	if len(result.IPv4) > 0 {
   716  		return result.IPv4[0], nil
   717  	}
   718  	return netip.Addr{}, ErrDomainNoAssociatedIPs
   719  }
   720  
   721  // LookupIPs implements [SimpleResolver.LookupIPs].
   722  func (r *Resolver) LookupIPs(ctx context.Context, name string) ([]netip.Addr, error) {
   723  	result, err := r.lookup(ctx, name)
   724  	if err != nil {
   725  		return nil, err
   726  	}
   727  
   728  	ips := make([]netip.Addr, 0, len(result.IPv6)+len(result.IPv4))
   729  	ips = append(ips, result.IPv6...)
   730  	ips = append(ips, result.IPv4...)
   731  	return ips, nil
   732  }
   733  
   734  // SystemResolver resolves names using [net.DefaultResolver].
   735  // It implements [SimpleResolver].
   736  type SystemResolver struct {
   737  	name   string
   738  	logger *zap.Logger
   739  }
   740  
   741  // NewSystemResolver returns a new [SystemResolver].
   742  func NewSystemResolver(name string, logger *zap.Logger) *SystemResolver {
   743  	return &SystemResolver{
   744  		name:   name,
   745  		logger: logger,
   746  	}
   747  }
   748  
   749  // LookupIP implements [SimpleResolver.LookupIP].
   750  func (r *SystemResolver) LookupIP(ctx context.Context, name string) (netip.Addr, error) {
   751  	ips, err := r.LookupIPs(ctx, name)
   752  	if err != nil {
   753  		return netip.Addr{}, err
   754  	}
   755  	if len(ips) == 0 {
   756  		return netip.Addr{}, ErrDomainNoAssociatedIPs
   757  	}
   758  	return ips[0], nil
   759  }
   760  
   761  // LookupIPs implements [SimpleResolver.LookupIPs].
   762  func (r *SystemResolver) LookupIPs(ctx context.Context, name string) ([]netip.Addr, error) {
   763  	ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", name)
   764  	if err != nil {
   765  		return nil, err
   766  	}
   767  
   768  	if ce := r.logger.Check(zap.DebugLevel, "DNS lookup got result from system resolver"); ce != nil {
   769  		ce.Write(
   770  			zap.String("resolver", r.name),
   771  			zap.String("name", name),
   772  			zap.Stringers("ips", ips),
   773  		)
   774  	}
   775  
   776  	return ips, nil
   777  }