github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/app/dns/dnscommon.go (about)

     1  package dns
     2  
     3  import (
     4  	"encoding/binary"
     5  	"strings"
     6  	"time"
     7  
     8  	"golang.org/x/net/dns/dnsmessage"
     9  
    10  	"github.com/v2fly/v2ray-core/v5/common"
    11  	"github.com/v2fly/v2ray-core/v5/common/errors"
    12  	"github.com/v2fly/v2ray-core/v5/common/net"
    13  	dns_feature "github.com/v2fly/v2ray-core/v5/features/dns"
    14  )
    15  
    16  // Fqdn normalizes domain make sure it ends with '.'
    17  func Fqdn(domain string) string {
    18  	if len(domain) > 0 && strings.HasSuffix(domain, ".") {
    19  		return domain
    20  	}
    21  	return domain + "."
    22  }
    23  
    24  type record struct {
    25  	A    *IPRecord
    26  	AAAA *IPRecord
    27  }
    28  
    29  // IPRecord is a cacheable item for a resolved domain
    30  type IPRecord struct {
    31  	ReqID  uint16
    32  	IP     []net.Address
    33  	Expire time.Time
    34  	RCode  dnsmessage.RCode
    35  }
    36  
    37  func (r *IPRecord) getIPs() ([]net.Address, error) {
    38  	if r == nil || r.Expire.Before(time.Now()) {
    39  		return nil, errRecordNotFound
    40  	}
    41  	if r.RCode != dnsmessage.RCodeSuccess {
    42  		return nil, dns_feature.RCodeError(r.RCode)
    43  	}
    44  	return r.IP, nil
    45  }
    46  
    47  func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
    48  	if newRec == nil {
    49  		return false
    50  	}
    51  	if baseRec == nil {
    52  		return true
    53  	}
    54  	return baseRec.Expire.Before(newRec.Expire)
    55  }
    56  
    57  var errRecordNotFound = errors.New("record not found")
    58  
    59  type dnsRequest struct {
    60  	reqType dnsmessage.Type
    61  	domain  string
    62  	start   time.Time
    63  	expire  time.Time
    64  	msg     *dnsmessage.Message
    65  }
    66  
    67  func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
    68  	if len(clientIP) == 0 {
    69  		return nil
    70  	}
    71  
    72  	var netmask int
    73  	var family uint16
    74  
    75  	if len(clientIP) == 4 {
    76  		family = 1
    77  		netmask = 24 // 24 for IPV4, 96 for IPv6
    78  	} else {
    79  		family = 2
    80  		netmask = 96
    81  	}
    82  
    83  	b := make([]byte, 4)
    84  	binary.BigEndian.PutUint16(b[0:], family)
    85  	b[2] = byte(netmask)
    86  	b[3] = 0
    87  	switch family {
    88  	case 1:
    89  		ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
    90  		needLength := (netmask + 8 - 1) / 8 // division rounding up
    91  		b = append(b, ip[:needLength]...)
    92  	case 2:
    93  		ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
    94  		needLength := (netmask + 8 - 1) / 8 // division rounding up
    95  		b = append(b, ip[:needLength]...)
    96  	}
    97  
    98  	const EDNS0SUBNET = 0x08
    99  
   100  	opt := new(dnsmessage.Resource)
   101  	common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
   102  
   103  	opt.Body = &dnsmessage.OPTResource{
   104  		Options: []dnsmessage.Option{
   105  			{
   106  				Code: EDNS0SUBNET,
   107  				Data: b,
   108  			},
   109  		},
   110  	}
   111  
   112  	return opt
   113  }
   114  
   115  func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
   116  	qA := dnsmessage.Question{
   117  		Name:  dnsmessage.MustNewName(domain),
   118  		Type:  dnsmessage.TypeA,
   119  		Class: dnsmessage.ClassINET,
   120  	}
   121  
   122  	qAAAA := dnsmessage.Question{
   123  		Name:  dnsmessage.MustNewName(domain),
   124  		Type:  dnsmessage.TypeAAAA,
   125  		Class: dnsmessage.ClassINET,
   126  	}
   127  
   128  	var reqs []*dnsRequest
   129  	now := time.Now()
   130  
   131  	if option.IPv4Enable {
   132  		msg := new(dnsmessage.Message)
   133  		msg.Header.ID = reqIDGen()
   134  		msg.Header.RecursionDesired = true
   135  		msg.Questions = []dnsmessage.Question{qA}
   136  		if reqOpts != nil {
   137  			msg.Additionals = append(msg.Additionals, *reqOpts)
   138  		}
   139  		reqs = append(reqs, &dnsRequest{
   140  			reqType: dnsmessage.TypeA,
   141  			domain:  domain,
   142  			start:   now,
   143  			msg:     msg,
   144  		})
   145  	}
   146  
   147  	if option.IPv6Enable {
   148  		msg := new(dnsmessage.Message)
   149  		msg.Header.ID = reqIDGen()
   150  		msg.Header.RecursionDesired = true
   151  		msg.Questions = []dnsmessage.Question{qAAAA}
   152  		if reqOpts != nil {
   153  			msg.Additionals = append(msg.Additionals, *reqOpts)
   154  		}
   155  		reqs = append(reqs, &dnsRequest{
   156  			reqType: dnsmessage.TypeAAAA,
   157  			domain:  domain,
   158  			start:   now,
   159  			msg:     msg,
   160  		})
   161  	}
   162  
   163  	return reqs
   164  }
   165  
   166  // parseResponse parses DNS answers from the returned payload
   167  func parseResponse(payload []byte) (*IPRecord, error) {
   168  	var parser dnsmessage.Parser
   169  	h, err := parser.Start(payload)
   170  	if err != nil {
   171  		return nil, newError("failed to parse DNS response").Base(err).AtWarning()
   172  	}
   173  	if err := parser.SkipAllQuestions(); err != nil {
   174  		return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning()
   175  	}
   176  
   177  	now := time.Now()
   178  	ipRecord := &IPRecord{
   179  		ReqID:  h.ID,
   180  		RCode:  h.RCode,
   181  		Expire: now.Add(time.Second * 600),
   182  	}
   183  
   184  L:
   185  	for {
   186  		ah, err := parser.AnswerHeader()
   187  		if err != nil {
   188  			if err != dnsmessage.ErrSectionDone {
   189  				newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog()
   190  			}
   191  			break
   192  		}
   193  
   194  		ttl := ah.TTL
   195  		if ttl == 0 {
   196  			ttl = 600
   197  		}
   198  		expire := now.Add(time.Duration(ttl) * time.Second)
   199  		if ipRecord.Expire.After(expire) {
   200  			ipRecord.Expire = expire
   201  		}
   202  
   203  		switch ah.Type {
   204  		case dnsmessage.TypeA:
   205  			ans, err := parser.AResource()
   206  			if err != nil {
   207  				newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
   208  				break L
   209  			}
   210  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
   211  		case dnsmessage.TypeAAAA:
   212  			ans, err := parser.AAAAResource()
   213  			if err != nil {
   214  				newError("failed to parse AAAA record for domain: ", ah.Name).Base(err).WriteToLog()
   215  				break L
   216  			}
   217  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
   218  		default:
   219  			if err := parser.SkipAnswer(); err != nil {
   220  				newError("failed to skip answer").Base(err).WriteToLog()
   221  				break L
   222  			}
   223  			continue
   224  		}
   225  	}
   226  
   227  	return ipRecord, nil
   228  }
   229  
   230  func filterIP(ips []net.Address, option dns_feature.IPOption) []net.Address {
   231  	filtered := make([]net.Address, 0, len(ips))
   232  	for _, ip := range ips {
   233  		if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) {
   234  			filtered = append(filtered, ip)
   235  		}
   236  	}
   237  	return filtered
   238  }