github.com/EagleQL/Xray-core@v1.4.3/app/dns/dnscommon.go (about) 1 package dns 2 3 import ( 4 "encoding/binary" 5 "strings" 6 "time" 7 8 "github.com/xtls/xray-core/common" 9 "github.com/xtls/xray-core/common/errors" 10 "github.com/xtls/xray-core/common/net" 11 dns_feature "github.com/xtls/xray-core/features/dns" 12 "golang.org/x/net/dns/dnsmessage" 13 ) 14 15 // Fqdn normalize domain make sure it ends with '.' 16 func Fqdn(domain string) string { 17 if len(domain) > 0 && strings.HasSuffix(domain, ".") { 18 return domain 19 } 20 return domain + "." 21 } 22 23 type record struct { 24 A *IPRecord 25 AAAA *IPRecord 26 } 27 28 // IPRecord is a cacheable item for a resolved domain 29 type IPRecord struct { 30 ReqID uint16 31 IP []net.Address 32 Expire time.Time 33 RCode dnsmessage.RCode 34 } 35 36 func (r *IPRecord) getIPs() ([]net.Address, error) { 37 if r == nil || r.Expire.Before(time.Now()) { 38 return nil, errRecordNotFound 39 } 40 if r.RCode != dnsmessage.RCodeSuccess { 41 return nil, dns_feature.RCodeError(r.RCode) 42 } 43 return r.IP, nil 44 } 45 46 func isNewer(baseRec *IPRecord, newRec *IPRecord) bool { 47 if newRec == nil { 48 return false 49 } 50 if baseRec == nil { 51 return true 52 } 53 return baseRec.Expire.Before(newRec.Expire) 54 } 55 56 var ( 57 errRecordNotFound = errors.New("record not found") 58 ) 59 60 type dnsRequest struct { 61 reqType dnsmessage.Type 62 domain string 63 start time.Time 64 expire time.Time 65 msg *dnsmessage.Message 66 } 67 68 func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource { 69 if len(clientIP) == 0 { 70 return nil 71 } 72 73 var netmask int 74 var family uint16 75 76 if len(clientIP) == 4 { 77 family = 1 78 netmask = 24 // 24 for IPV4, 96 for IPv6 79 } else { 80 family = 2 81 netmask = 96 82 } 83 84 b := make([]byte, 4) 85 binary.BigEndian.PutUint16(b[0:], family) 86 b[2] = byte(netmask) 87 b[3] = 0 88 switch family { 89 case 1: 90 ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8)) 91 needLength := (netmask + 8 - 1) / 8 // division rounding up 92 b = append(b, ip[:needLength]...) 93 case 2: 94 ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8)) 95 needLength := (netmask + 8 - 1) / 8 // division rounding up 96 b = append(b, ip[:needLength]...) 97 } 98 99 const EDNS0SUBNET = 0x08 100 101 opt := new(dnsmessage.Resource) 102 common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true)) 103 104 opt.Body = &dnsmessage.OPTResource{ 105 Options: []dnsmessage.Option{ 106 { 107 Code: EDNS0SUBNET, 108 Data: b, 109 }, 110 }, 111 } 112 113 return opt 114 } 115 116 func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest { 117 qA := dnsmessage.Question{ 118 Name: dnsmessage.MustNewName(domain), 119 Type: dnsmessage.TypeA, 120 Class: dnsmessage.ClassINET, 121 } 122 123 qAAAA := dnsmessage.Question{ 124 Name: dnsmessage.MustNewName(domain), 125 Type: dnsmessage.TypeAAAA, 126 Class: dnsmessage.ClassINET, 127 } 128 129 var reqs []*dnsRequest 130 now := time.Now() 131 132 if option.IPv4Enable { 133 msg := new(dnsmessage.Message) 134 msg.Header.ID = reqIDGen() 135 msg.Header.RecursionDesired = true 136 msg.Questions = []dnsmessage.Question{qA} 137 if reqOpts != nil { 138 msg.Additionals = append(msg.Additionals, *reqOpts) 139 } 140 reqs = append(reqs, &dnsRequest{ 141 reqType: dnsmessage.TypeA, 142 domain: domain, 143 start: now, 144 msg: msg, 145 }) 146 } 147 148 if option.IPv6Enable { 149 msg := new(dnsmessage.Message) 150 msg.Header.ID = reqIDGen() 151 msg.Header.RecursionDesired = true 152 msg.Questions = []dnsmessage.Question{qAAAA} 153 if reqOpts != nil { 154 msg.Additionals = append(msg.Additionals, *reqOpts) 155 } 156 reqs = append(reqs, &dnsRequest{ 157 reqType: dnsmessage.TypeAAAA, 158 domain: domain, 159 start: now, 160 msg: msg, 161 }) 162 } 163 164 return reqs 165 } 166 167 // parseResponse parse DNS answers from the returned payload 168 func parseResponse(payload []byte) (*IPRecord, error) { 169 var parser dnsmessage.Parser 170 h, err := parser.Start(payload) 171 if err != nil { 172 return nil, newError("failed to parse DNS response").Base(err).AtWarning() 173 } 174 if err := parser.SkipAllQuestions(); err != nil { 175 return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning() 176 } 177 178 now := time.Now() 179 ipRecord := &IPRecord{ 180 ReqID: h.ID, 181 RCode: h.RCode, 182 Expire: now.Add(time.Second * 600), 183 } 184 185 L: 186 for { 187 ah, err := parser.AnswerHeader() 188 if err != nil { 189 if err != dnsmessage.ErrSectionDone { 190 newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog() 191 } 192 break 193 } 194 195 ttl := ah.TTL 196 if ttl == 0 { 197 ttl = 600 198 } 199 expire := now.Add(time.Duration(ttl) * time.Second) 200 if ipRecord.Expire.After(expire) { 201 ipRecord.Expire = expire 202 } 203 204 switch ah.Type { 205 case dnsmessage.TypeA: 206 ans, err := parser.AResource() 207 if err != nil { 208 newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog() 209 break L 210 } 211 ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:])) 212 case dnsmessage.TypeAAAA: 213 ans, err := parser.AAAAResource() 214 if err != nil { 215 newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog() 216 break L 217 } 218 ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:])) 219 default: 220 if err := parser.SkipAnswer(); err != nil { 221 newError("failed to skip answer").Base(err).WriteToLog() 222 break L 223 } 224 continue 225 } 226 } 227 228 return ipRecord, nil 229 }