github.com/moqsien/xraycore@v1.8.5/app/dns/dnscommon.go (about)

     1  package dns
     2  
     3  import (
     4  	"context"
     5  	"encoding/binary"
     6  	"strings"
     7  	"time"
     8  
     9  	"github.com/moqsien/xraycore/common"
    10  	"github.com/moqsien/xraycore/common/errors"
    11  	"github.com/moqsien/xraycore/common/log"
    12  	"github.com/moqsien/xraycore/common/net"
    13  	"github.com/moqsien/xraycore/common/session"
    14  	"github.com/moqsien/xraycore/core"
    15  	dns_feature "github.com/moqsien/xraycore/features/dns"
    16  	"golang.org/x/net/dns/dnsmessage"
    17  )
    18  
    19  // Fqdn normalizes domain make sure it ends with '.'
    20  func Fqdn(domain string) string {
    21  	if len(domain) > 0 && strings.HasSuffix(domain, ".") {
    22  		return domain
    23  	}
    24  	return domain + "."
    25  }
    26  
    27  type record struct {
    28  	A    *IPRecord
    29  	AAAA *IPRecord
    30  }
    31  
// IPRecord is a cacheable item for a resolved domain. parseResponse builds one
// from a raw DNS response.
//
// A nil *IPRecord, or one whose Expire time has passed, is reported by getIPs
// as errRecordNotFound.
type IPRecord struct {
	ReqID  uint16           // ID taken from the DNS response header
	IP     []net.Address    // addresses from A and AAAA answers, in answer order
	Expire time.Time        // once this instant has passed, getIPs reports errRecordNotFound
	RCode  dnsmessage.RCode // response code; anything but RCodeSuccess makes getIPs return an RCodeError
}
    39  
    40  func (r *IPRecord) getIPs() ([]net.Address, error) {
    41  	if r == nil || r.Expire.Before(time.Now()) {
    42  		return nil, errRecordNotFound
    43  	}
    44  	if r.RCode != dnsmessage.RCodeSuccess {
    45  		return nil, dns_feature.RCodeError(r.RCode)
    46  	}
    47  	return r.IP, nil
    48  }
    49  
    50  func isNewer(baseRec *IPRecord, newRec *IPRecord) bool {
    51  	if newRec == nil {
    52  		return false
    53  	}
    54  	if baseRec == nil {
    55  		return true
    56  	}
    57  	return baseRec.Expire.Before(newRec.Expire)
    58  }
    59  
    60  var errRecordNotFound = errors.New("record not found")
    61  
    62  type dnsRequest struct {
    63  	reqType dnsmessage.Type
    64  	domain  string
    65  	start   time.Time
    66  	expire  time.Time
    67  	msg     *dnsmessage.Message
    68  }
    69  
    70  func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource {
    71  	if len(clientIP) == 0 {
    72  		return nil
    73  	}
    74  
    75  	var netmask int
    76  	var family uint16
    77  
    78  	if len(clientIP) == 4 {
    79  		family = 1
    80  		netmask = 24 // 24 for IPV4, 96 for IPv6
    81  	} else {
    82  		family = 2
    83  		netmask = 96
    84  	}
    85  
    86  	b := make([]byte, 4)
    87  	binary.BigEndian.PutUint16(b[0:], family)
    88  	b[2] = byte(netmask)
    89  	b[3] = 0
    90  	switch family {
    91  	case 1:
    92  		ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8))
    93  		needLength := (netmask + 8 - 1) / 8 // division rounding up
    94  		b = append(b, ip[:needLength]...)
    95  	case 2:
    96  		ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8))
    97  		needLength := (netmask + 8 - 1) / 8 // division rounding up
    98  		b = append(b, ip[:needLength]...)
    99  	}
   100  
   101  	const EDNS0SUBNET = 0x08
   102  
   103  	opt := new(dnsmessage.Resource)
   104  	common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true))
   105  
   106  	opt.Body = &dnsmessage.OPTResource{
   107  		Options: []dnsmessage.Option{
   108  			{
   109  				Code: EDNS0SUBNET,
   110  				Data: b,
   111  			},
   112  		},
   113  	}
   114  
   115  	return opt
   116  }
   117  
   118  func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest {
   119  	qA := dnsmessage.Question{
   120  		Name:  dnsmessage.MustNewName(domain),
   121  		Type:  dnsmessage.TypeA,
   122  		Class: dnsmessage.ClassINET,
   123  	}
   124  
   125  	qAAAA := dnsmessage.Question{
   126  		Name:  dnsmessage.MustNewName(domain),
   127  		Type:  dnsmessage.TypeAAAA,
   128  		Class: dnsmessage.ClassINET,
   129  	}
   130  
   131  	var reqs []*dnsRequest
   132  	now := time.Now()
   133  
   134  	if option.IPv4Enable {
   135  		msg := new(dnsmessage.Message)
   136  		msg.Header.ID = reqIDGen()
   137  		msg.Header.RecursionDesired = true
   138  		msg.Questions = []dnsmessage.Question{qA}
   139  		if reqOpts != nil {
   140  			msg.Additionals = append(msg.Additionals, *reqOpts)
   141  		}
   142  		reqs = append(reqs, &dnsRequest{
   143  			reqType: dnsmessage.TypeA,
   144  			domain:  domain,
   145  			start:   now,
   146  			msg:     msg,
   147  		})
   148  	}
   149  
   150  	if option.IPv6Enable {
   151  		msg := new(dnsmessage.Message)
   152  		msg.Header.ID = reqIDGen()
   153  		msg.Header.RecursionDesired = true
   154  		msg.Questions = []dnsmessage.Question{qAAAA}
   155  		if reqOpts != nil {
   156  			msg.Additionals = append(msg.Additionals, *reqOpts)
   157  		}
   158  		reqs = append(reqs, &dnsRequest{
   159  			reqType: dnsmessage.TypeAAAA,
   160  			domain:  domain,
   161  			start:   now,
   162  			msg:     msg,
   163  		})
   164  	}
   165  
   166  	return reqs
   167  }
   168  
   169  // parseResponse parses DNS answers from the returned payload
   170  func parseResponse(payload []byte) (*IPRecord, error) {
   171  	var parser dnsmessage.Parser
   172  	h, err := parser.Start(payload)
   173  	if err != nil {
   174  		return nil, newError("failed to parse DNS response").Base(err).AtWarning()
   175  	}
   176  	if err := parser.SkipAllQuestions(); err != nil {
   177  		return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning()
   178  	}
   179  
   180  	now := time.Now()
   181  	ipRecord := &IPRecord{
   182  		ReqID:  h.ID,
   183  		RCode:  h.RCode,
   184  		Expire: now.Add(time.Second * 600),
   185  	}
   186  
   187  L:
   188  	for {
   189  		ah, err := parser.AnswerHeader()
   190  		if err != nil {
   191  			if err != dnsmessage.ErrSectionDone {
   192  				newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog()
   193  			}
   194  			break
   195  		}
   196  
   197  		ttl := ah.TTL
   198  		if ttl == 0 {
   199  			ttl = 600
   200  		}
   201  		expire := now.Add(time.Duration(ttl) * time.Second)
   202  		if ipRecord.Expire.After(expire) {
   203  			ipRecord.Expire = expire
   204  		}
   205  
   206  		switch ah.Type {
   207  		case dnsmessage.TypeA:
   208  			ans, err := parser.AResource()
   209  			if err != nil {
   210  				newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog()
   211  				break L
   212  			}
   213  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:]))
   214  		case dnsmessage.TypeAAAA:
   215  			ans, err := parser.AAAAResource()
   216  			if err != nil {
   217  				newError("failed to parse AAAA record for domain: ", ah.Name).Base(err).WriteToLog()
   218  				break L
   219  			}
   220  			ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:]))
   221  		default:
   222  			if err := parser.SkipAnswer(); err != nil {
   223  				newError("failed to skip answer").Base(err).WriteToLog()
   224  				break L
   225  			}
   226  			continue
   227  		}
   228  	}
   229  
   230  	return ipRecord, nil
   231  }
   232  
   233  // toDnsContext create a new background context with parent inbound, session and dns log
   234  func toDnsContext(ctx context.Context, addr string) context.Context {
   235  	dnsCtx := core.ToBackgroundDetachedContext(ctx)
   236  	if inbound := session.InboundFromContext(ctx); inbound != nil {
   237  		dnsCtx = session.ContextWithInbound(dnsCtx, inbound)
   238  	}
   239  	dnsCtx = session.ContextWithContent(dnsCtx, session.ContentFromContext(ctx))
   240  	dnsCtx = log.ContextWithAccessMessage(dnsCtx, &log.AccessMessage{
   241  		From:   "DNS",
   242  		To:     addr,
   243  		Status: log.AccessAccepted,
   244  		Reason: "",
   245  	})
   246  	return dnsCtx
   247  }