github.com/EagleQL/Xray-core@v1.4.3/app/dns/dnscommon.go (about)

     1  package dns
     2  
     3  import (
     4  	"encoding/binary"
     5  	"strings"
     6  	"time"
     7  
     8  	"github.com/xtls/xray-core/common"
     9  	"github.com/xtls/xray-core/common/errors"
    10  	"github.com/xtls/xray-core/common/net"
    11  	dns_feature "github.com/xtls/xray-core/features/dns"
    12  	"golang.org/x/net/dns/dnsmessage"
    13  )
    14  
    15  // Fqdn normalize domain make sure it ends with '.'
    16  func Fqdn(domain string) string {
    17  	if len(domain) > 0 && strings.HasSuffix(domain, ".") {
    18  		return domain
    19  	}
    20  	return domain + "."
    21  }
    22  
// record pairs the IPv4 (A) and IPv6 (AAAA) lookup results, presumably for a
// single domain held in a cache outside this view — verify. Either field may
// be nil; (*IPRecord).getIPs treats a nil record as not found.
type record struct {
	A    *IPRecord // result of the A (IPv4) query
	AAAA *IPRecord // result of the AAAA (IPv6) query
}
    27  
// IPRecord is a cacheable item for a resolved domain: the addresses parsed
// from one DNS response together with that response's metadata.
type IPRecord struct {
	ReqID  uint16           // ID of the DNS message this record was parsed from
	IP     []net.Address    // addresses from A/AAAA answers; may be empty
	Expire time.Time        // after this instant getIPs reports the record as not found
	RCode  dnsmessage.RCode // response code of the DNS message
}
    35  
    36  func (r *IPRecord) getIPs() ([]net.Address, error) {
    37  	if r == nil || r.Expire.Before(time.Now()) {
    38  		return nil, errRecordNotFound
    39  	}
    40  	if r.RCode != dnsmessage.RCodeSuccess {
    41  		return nil, dns_feature.RCodeError(r.RCode)
    42  	}
    43  	return r.IP, nil
    44  }
    45  
    46  func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
    47  	if newRec == nil {
    48  		return false
    49  	}
    50  	if baseRec == nil {
    51  		return true
    52  	}
    53  	return baseRec.Expire.Before(newRec.Expire)
    54  }
    55  
    56  var (
    57  	errRecordNotFound = errors.New("record not found")
    58  )
    59  
// dnsRequest is a single outgoing DNS query, as produced by buildReqMsgs.
type dnsRequest struct {
	reqType dnsmessage.Type     // query type (A or AAAA when built by buildReqMsgs)
	domain  string              // queried domain name
	start   time.Time           // when the request was built
	expire  time.Time           // not set by buildReqMsgs; presumably filled in by the sender — verify
	msg     *dnsmessage.Message // query message to send
}
    67  
    68  func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
    69  	if len(clientIP) == 0 {
    70  		return nil
    71  	}
    72  
    73  	var netmask int
    74  	var family uint16
    75  
    76  	if len(clientIP) == 4 {
    77  		family = 1
    78  		netmask = 24 // 24 for IPV4, 96 for IPv6
    79  	} else {
    80  		family = 2
    81  		netmask = 96
    82  	}
    83  
    84  	b := make([]byte, 4)
    85  	binary.BigEndian.PutUint16(b[0:], family)
    86  	b[2] = byte(netmask)
    87  	b[3] = 0
    88  	switch family {
    89  	case 1:
    90  		ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
    91  		needLength := (netmask + 8 - 1) / 8 // division rounding up
    92  		b = append(b, ip[:needLength]...)
    93  	case 2:
    94  		ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
    95  		needLength := (netmask + 8 - 1) / 8 // division rounding up
    96  		b = append(b, ip[:needLength]...)
    97  	}
    98  
    99  	const EDNS0SUBNET = 0x08
   100  
   101  	opt := new(dnsmessage.Resource)
   102  	common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
   103  
   104  	opt.Body = &dnsmessage.OPTResource{
   105  		Options: []dnsmessage.Option{
   106  			{
   107  				Code: EDNS0SUBNET,
   108  				Data: b,
   109  			},
   110  		},
   111  	}
   112  
   113  	return opt
   114  }
   115  
   116  func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
   117  	qA := dnsmessage.Question{
   118  		Name:  dnsmessage.MustNewName(domain),
   119  		Type:  dnsmessage.TypeA,
   120  		Class: dnsmessage.ClassINET,
   121  	}
   122  
   123  	qAAAA := dnsmessage.Question{
   124  		Name:  dnsmessage.MustNewName(domain),
   125  		Type:  dnsmessage.TypeAAAA,
   126  		Class: dnsmessage.ClassINET,
   127  	}
   128  
   129  	var reqs []*dnsRequest
   130  	now := time.Now()
   131  
   132  	if option.IPv4Enable {
   133  		msg := new(dnsmessage.Message)
   134  		msg.Header.ID = reqIDGen()
   135  		msg.Header.RecursionDesired = true
   136  		msg.Questions = []dnsmessage.Question{qA}
   137  		if reqOpts != nil {
   138  			msg.Additionals = append(msg.Additionals, *reqOpts)
   139  		}
   140  		reqs = append(reqs, &dnsRequest{
   141  			reqType: dnsmessage.TypeA,
   142  			domain:  domain,
   143  			start:   now,
   144  			msg:     msg,
   145  		})
   146  	}
   147  
   148  	if option.IPv6Enable {
   149  		msg := new(dnsmessage.Message)
   150  		msg.Header.ID = reqIDGen()
   151  		msg.Header.RecursionDesired = true
   152  		msg.Questions = []dnsmessage.Question{qAAAA}
   153  		if reqOpts != nil {
   154  			msg.Additionals = append(msg.Additionals, *reqOpts)
   155  		}
   156  		reqs = append(reqs, &dnsRequest{
   157  			reqType: dnsmessage.TypeAAAA,
   158  			domain:  domain,
   159  			start:   now,
   160  			msg:     msg,
   161  		})
   162  	}
   163  
   164  	return reqs
   165  }
   166  
   167  // parseResponse parse DNS answers from the returned payload
   168  func parseResponse(payload []byte) (*IPRecord, error) {
   169  	var parser dnsmessage.Parser
   170  	h, err := parser.Start(payload)
   171  	if err != nil {
   172  		return nil, newError("failed to parse DNS response").Base(err).AtWarning()
   173  	}
   174  	if err := parser.SkipAllQuestions(); err != nil {
   175  		return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning()
   176  	}
   177  
   178  	now := time.Now()
   179  	ipRecord := &IPRecord{
   180  		ReqID:  h.ID,
   181  		RCode:  h.RCode,
   182  		Expire: now.Add(time.Second * 600),
   183  	}
   184  
   185  L:
   186  	for {
   187  		ah, err := parser.AnswerHeader()
   188  		if err != nil {
   189  			if err != dnsmessage.ErrSectionDone {
   190  				newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog()
   191  			}
   192  			break
   193  		}
   194  
   195  		ttl := ah.TTL
   196  		if ttl == 0 {
   197  			ttl = 600
   198  		}
   199  		expire := now.Add(time.Duration(ttl) * time.Second)
   200  		if ipRecord.Expire.After(expire) {
   201  			ipRecord.Expire = expire
   202  		}
   203  
   204  		switch ah.Type {
   205  		case dnsmessage.TypeA:
   206  			ans, err := parser.AResource()
   207  			if err != nil {
   208  				newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
   209  				break L
   210  			}
   211  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
   212  		case dnsmessage.TypeAAAA:
   213  			ans, err := parser.AAAAResource()
   214  			if err != nil {
   215  				newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
   216  				break L
   217  			}
   218  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
   219  		default:
   220  			if err := parser.SkipAnswer(); err != nil {
   221  				newError("failed to skip answer").Base(err).WriteToLog()
   222  				break L
   223  			}
   224  			continue
   225  		}
   226  	}
   227  
   228  	return ipRecord, nil
   229  }