github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/app/dns/dnscommon.go (about) 1 package dns 2 3 import ( 4 "encoding/binary" 5 "strings" 6 "time" 7 8 "golang.org/x/net/dns/dnsmessage" 9 10 "github.com/v2fly/v2ray-core/v5/common" 11 "github.com/v2fly/v2ray-core/v5/common/errors" 12 "github.com/v2fly/v2ray-core/v5/common/net" 13 dns_feature "github.com/v2fly/v2ray-core/v5/features/dns" 14 ) 15 16 // Fqdn normalizes domain make sure it ends with '.' 17 func Fqdn(domain string) string { 18 if len(domain) > 0 && strings.HasSuffix(domain, ".") { 19 return domain 20 } 21 return domain + "." 22 } 23 24 type record struct { 25 A *IPRecord 26 AAAA *IPRecord 27 } 28 29 // IPRecord is a cacheable item for a resolved domain 30 type IPRecord struct { 31 ReqID uint16 32 IP []net.Address 33 Expire time.Time 34 RCode dnsmessage.RCode 35 } 36 37 func (r *IPRecord) getIPs() ([]net.Address, error) { 38 if r == nil || r.Expire.Before(time.Now()) { 39 return nil, errRecordNotFound 40 } 41 if r.RCode != dnsmessage.RCodeSuccess { 42 return nil, dns_feature.RCodeError(r.RCode) 43 } 44 return r.IP, nil 45 } 46 47 func isNewer(baseRec *IPRecord, newRec *IPRecord) bool { 48 if newRec == nil { 49 return false 50 } 51 if baseRec == nil { 52 return true 53 } 54 return baseRec.Expire.Before(newRec.Expire) 55 } 56 57 var errRecordNotFound = errors.New("record not found") 58 59 type dnsRequest struct { 60 reqType dnsmessage.Type 61 domain string 62 start time.Time 63 expire time.Time 64 msg *dnsmessage.Message 65 } 66 67 func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource { 68 if len(clientIP) == 0 { 69 return nil 70 } 71 72 var netmask int 73 var family uint16 74 75 if len(clientIP) == 4 { 76 family = 1 77 netmask = 24 // 24 for IPV4, 96 for IPv6 78 } else { 79 family = 2 80 netmask = 96 81 } 82 83 b := make([]byte, 4) 84 binary.BigEndian.PutUint16(b[0:], family) 85 b[2] = byte(netmask) 86 b[3] = 0 87 switch family { 88 case 1: 89 ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8)) 90 needLength := (netmask + 8 - 1) / 8 // division rounding up 91 b = append(b, ip[:needLength]...) 92 case 2: 93 ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8)) 94 needLength := (netmask + 8 - 1) / 8 // division rounding up 95 b = append(b, ip[:needLength]...) 96 } 97 98 const EDNS0SUBNET = 0x08 99 100 opt := new(dnsmessage.Resource) 101 common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true)) 102 103 opt.Body = &dnsmessage.OPTResource{ 104 Options: []dnsmessage.Option{ 105 { 106 Code: EDNS0SUBNET, 107 Data: b, 108 }, 109 }, 110 } 111 112 return opt 113 } 114 115 func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest { 116 qA := dnsmessage.Question{ 117 Name: dnsmessage.MustNewName(domain), 118 Type: dnsmessage.TypeA, 119 Class: dnsmessage.ClassINET, 120 } 121 122 qAAAA := dnsmessage.Question{ 123 Name: dnsmessage.MustNewName(domain), 124 Type: dnsmessage.TypeAAAA, 125 Class: dnsmessage.ClassINET, 126 } 127 128 var reqs []*dnsRequest 129 now := time.Now() 130 131 if option.IPv4Enable { 132 msg := new(dnsmessage.Message) 133 msg.Header.ID = reqIDGen() 134 msg.Header.RecursionDesired = true 135 msg.Questions = []dnsmessage.Question{qA} 136 if reqOpts != nil { 137 msg.Additionals = append(msg.Additionals, *reqOpts) 138 } 139 reqs = append(reqs, &dnsRequest{ 140 reqType: dnsmessage.TypeA, 141 domain: domain, 142 start: now, 143 msg: msg, 144 }) 145 } 146 147 if option.IPv6Enable { 148 msg := new(dnsmessage.Message) 149 msg.Header.ID = reqIDGen() 150 msg.Header.RecursionDesired = true 151 msg.Questions = []dnsmessage.Question{qAAAA} 152 if reqOpts != nil { 153 msg.Additionals = append(msg.Additionals, *reqOpts) 154 } 155 reqs = append(reqs, &dnsRequest{ 156 reqType: dnsmessage.TypeAAAA, 157 domain: domain, 158 start: now, 159 msg: msg, 160 }) 161 } 162 163 return reqs 164 } 165 166 // parseResponse parses DNS answers from the returned payload 167 func parseResponse(payload []byte) (*IPRecord, error) { 168 var parser dnsmessage.Parser 169 h, err := parser.Start(payload) 170 if err != nil { 171 return nil, newError("failed to parse DNS response").Base(err).AtWarning() 172 } 173 if err := parser.SkipAllQuestions(); err != nil { 174 return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning() 175 } 176 177 now := time.Now() 178 ipRecord := &IPRecord{ 179 ReqID: h.ID, 180 RCode: h.RCode, 181 Expire: now.Add(time.Second * 600), 182 } 183 184 L: 185 for { 186 ah, err := parser.AnswerHeader() 187 if err != nil { 188 if err != dnsmessage.ErrSectionDone { 189 newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog() 190 } 191 break 192 } 193 194 ttl := ah.TTL 195 if ttl == 0 { 196 ttl = 600 197 } 198 expire := now.Add(time.Duration(ttl) * time.Second) 199 if ipRecord.Expire.After(expire) { 200 ipRecord.Expire = expire 201 } 202 203 switch ah.Type { 204 case dnsmessage.TypeA: 205 ans, err := parser.AResource() 206 if err != nil { 207 newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog() 208 break L 209 } 210 ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:])) 211 case dnsmessage.TypeAAAA: 212 ans, err := parser.AAAAResource() 213 if err != nil { 214 newError("failed to parse AAAA record for domain: ", ah.Name).Base(err).WriteToLog() 215 break L 216 } 217 ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:])) 218 default: 219 if err := parser.SkipAnswer(); err != nil { 220 newError("failed to skip answer").Base(err).WriteToLog() 221 break L 222 } 223 continue 224 } 225 } 226 227 return ipRecord, nil 228 } 229 230 func filterIP(ips []net.Address, option dns_feature.IPOption) []net.Address { 231 filtered := make([]net.Address, 0, len(ips)) 232 for _, ip := range ips { 233 if (ip.Family().IsIPv4() && option.IPv4Enable) || (ip.Family().IsIPv6() && option.IPv6Enable) { 234 filtered = append(filtered, ip) 235 } 236 } 237 return filtered 238 }