github.com/Uhtred009/v2ray-core-1@v4.31.2+incompatible/app/dns/dnscommon.go (about)

     1  // +build !confonly
     2  
     3  package dns
     4  
     5  import (
     6  	"encoding/binary"
     7  	"time"
     8  
     9  	"golang.org/x/net/dns/dnsmessage"
    10  	"v2ray.com/core/common"
    11  	"v2ray.com/core/common/errors"
    12  	"v2ray.com/core/common/net"
    13  	dns_feature "v2ray.com/core/features/dns"
    14  )
    15  
    16  // Fqdn normalize domain make sure it ends with '.'
    17  func Fqdn(domain string) string {
    18  	if len(domain) > 0 && domain[len(domain)-1] == '.' {
    19  		return domain
    20  	}
    21  	return domain + "."
    22  }
    23  
    24  type record struct {
    25  	A    *IPRecord
    26  	AAAA *IPRecord
    27  }
    28  
    29  // IPRecord is a cacheable item for a resolved domain
    30  type IPRecord struct {
    31  	ReqID  uint16
    32  	IP     []net.Address
    33  	Expire time.Time
    34  	RCode  dnsmessage.RCode
    35  }
    36  
    37  func (r *IPRecord) getIPs() ([]net.Address, error) {
    38  	if r == nil || r.Expire.Before(time.Now()) {
    39  		return nil, errRecordNotFound
    40  	}
    41  	if r.RCode != dnsmessage.RCodeSuccess {
    42  		return nil, dns_feature.RCodeError(r.RCode)
    43  	}
    44  	return r.IP, nil
    45  }
    46  
    47  func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
    48  	if newRec == nil {
    49  		return false
    50  	}
    51  	if baseRec == nil {
    52  		return true
    53  	}
    54  	return baseRec.Expire.Before(newRec.Expire)
    55  }
    56  
    57  var (
    58  	errRecordNotFound = errors.New("record not found")
    59  )
    60  
    61  type dnsRequest struct {
    62  	reqType dnsmessage.Type
    63  	domain  string
    64  	start   time.Time
    65  	expire  time.Time
    66  	msg     *dnsmessage.Message
    67  }
    68  
    69  func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
    70  	if len(clientIP) == 0 {
    71  		return nil
    72  	}
    73  
    74  	var netmask int
    75  	var family uint16
    76  
    77  	if len(clientIP) == 4 {
    78  		family = 1
    79  		netmask = 24 // 24 for IPV4, 96 for IPv6
    80  	} else {
    81  		family = 2
    82  		netmask = 96
    83  	}
    84  
    85  	b := make([]byte, 4)
    86  	binary.BigEndian.PutUint16(b[0:], family)
    87  	b[2] = byte(netmask)
    88  	b[3] = 0
    89  	switch family {
    90  	case 1:
    91  		ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
    92  		needLength := (netmask + 8 - 1) / 8 // division rounding up
    93  		b = append(b, ip[:needLength]...)
    94  	case 2:
    95  		ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
    96  		needLength := (netmask + 8 - 1) / 8 // division rounding up
    97  		b = append(b, ip[:needLength]...)
    98  	}
    99  
   100  	const EDNS0SUBNET = 0x08
   101  
   102  	opt := new(dnsmessage.Resource)
   103  	common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
   104  
   105  	opt.Body = &dnsmessage.OPTResource{
   106  		Options: []dnsmessage.Option{
   107  			{
   108  				Code: EDNS0SUBNET,
   109  				Data: b,
   110  			},
   111  		},
   112  	}
   113  
   114  	return opt
   115  }
   116  
   117  func buildReqMsgs(domain string, option IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
   118  	qA := dnsmessage.Question{
   119  		Name:  dnsmessage.MustNewName(domain),
   120  		Type:  dnsmessage.TypeA,
   121  		Class: dnsmessage.ClassINET,
   122  	}
   123  
   124  	qAAAA := dnsmessage.Question{
   125  		Name:  dnsmessage.MustNewName(domain),
   126  		Type:  dnsmessage.TypeAAAA,
   127  		Class: dnsmessage.ClassINET,
   128  	}
   129  
   130  	var reqs []*dnsRequest
   131  	now := time.Now()
   132  
   133  	if option.IPv4Enable {
   134  		msg := new(dnsmessage.Message)
   135  		msg.Header.ID = reqIDGen()
   136  		msg.Header.RecursionDesired = true
   137  		msg.Questions = []dnsmessage.Question{qA}
   138  		if reqOpts != nil {
   139  			msg.Additionals = append(msg.Additionals, *reqOpts)
   140  		}
   141  		reqs = append(reqs, &dnsRequest{
   142  			reqType: dnsmessage.TypeA,
   143  			domain:  domain,
   144  			start:   now,
   145  			msg:     msg,
   146  		})
   147  	}
   148  
   149  	if option.IPv6Enable {
   150  		msg := new(dnsmessage.Message)
   151  		msg.Header.ID = reqIDGen()
   152  		msg.Header.RecursionDesired = true
   153  		msg.Questions = []dnsmessage.Question{qAAAA}
   154  		if reqOpts != nil {
   155  			msg.Additionals = append(msg.Additionals, *reqOpts)
   156  		}
   157  		reqs = append(reqs, &dnsRequest{
   158  			reqType: dnsmessage.TypeAAAA,
   159  			domain:  domain,
   160  			start:   now,
   161  			msg:     msg,
   162  		})
   163  	}
   164  
   165  	return reqs
   166  }
   167  
   168  // parseResponse parse DNS answers from the returned payload
   169  func parseResponse(payload []byte) (*IPRecord, error) {
   170  	var parser dnsmessage.Parser
   171  	h, err := parser.Start(payload)
   172  	if err != nil {
   173  		return nil, newError("failed to parse DNS response").Base(err).AtWarning()
   174  	}
   175  	if err := parser.SkipAllQuestions(); err != nil {
   176  		return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning()
   177  	}
   178  
   179  	now := time.Now()
   180  	ipRecord := &IPRecord{
   181  		ReqID:  h.ID,
   182  		RCode:  h.RCode,
   183  		Expire: now.Add(time.Second * 600),
   184  	}
   185  
   186  L:
   187  	for {
   188  		ah, err := parser.AnswerHeader()
   189  		if err != nil {
   190  			if err != dnsmessage.ErrSectionDone {
   191  				newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog()
   192  			}
   193  			break
   194  		}
   195  
   196  		ttl := ah.TTL
   197  		if ttl == 0 {
   198  			ttl = 600
   199  		}
   200  		expire := now.Add(time.Duration(ttl) * time.Second)
   201  		if ipRecord.Expire.After(expire) {
   202  			ipRecord.Expire = expire
   203  		}
   204  
   205  		switch ah.Type {
   206  		case dnsmessage.TypeA:
   207  			ans, err := parser.AResource()
   208  			if err != nil {
   209  				newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
   210  				break L
   211  			}
   212  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
   213  		case dnsmessage.TypeAAAA:
   214  			ans, err := parser.AAAAResource()
   215  			if err != nil {
   216  				newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
   217  				break L
   218  			}
   219  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
   220  		default:
   221  			if err := parser.SkipAnswer(); err != nil {
   222  				newError("failed to skip answer").Base(err).WriteToLog()
   223  				break L
   224  			}
   225  			continue
   226  		}
   227  	}
   228  
   229  	return ipRecord, nil
   230  }