github.com/xmplusdev/xray-core@v1.8.10/app/dns/dnscommon.go (about) 1 package dns 2 3 import ( 4 "context" 5 "encoding/binary" 6 "strings" 7 "time" 8 9 "github.com/xmplusdev/xray-core/common" 10 "github.com/xmplusdev/xray-core/common/errors" 11 "github.com/xmplusdev/xray-core/common/log" 12 "github.com/xmplusdev/xray-core/common/net" 13 "github.com/xmplusdev/xray-core/common/session" 14 "github.com/xmplusdev/xray-core/core" 15 dns_feature "github.com/xmplusdev/xray-core/features/dns" 16 "golang.org/x/net/dns/dnsmessage" 17 ) 18 19 // Fqdn normalizes domain make sure it ends with '.' 20 func Fqdn(domain string) string { 21 if len(domain) > 0 && strings.HasSuffix(domain, ".") { 22 return domain 23 } 24 return domain + "." 25 } 26 27 type record struct { 28 A *IPRecord 29 AAAA *IPRecord 30 } 31 32 // IPRecord is a cacheable item for a resolved domain 33 type IPRecord struct { 34 ReqID uint16 35 IP []net.Address 36 Expire time.Time 37 RCode dnsmessage.RCode 38 } 39 40 func (r *IPRecord) getIPs() ([]net.Address, error) { 41 if r == nil || r.Expire.Before(time.Now()) { 42 return nil, errRecordNotFound 43 } 44 if r.RCode != dnsmessage.RCodeSuccess { 45 return nil, dns_feature.RCodeError(r.RCode) 46 } 47 return r.IP, nil 48 } 49 50 func isNewer(baseRec *IPRecord, newRec *IPRecord) bool { 51 if newRec == nil { 52 return false 53 } 54 if baseRec == nil { 55 return true 56 } 57 return baseRec.Expire.Before(newRec.Expire) 58 } 59 60 var errRecordNotFound = errors.New("record not found") 61 62 type dnsRequest struct { 63 reqType dnsmessage.Type 64 domain string 65 start time.Time 66 expire time.Time 67 msg *dnsmessage.Message 68 } 69 70 func genEDNS0Options(clientIP net.IP) *dnsmessage.Resource { 71 if len(clientIP) == 0 { 72 return nil 73 } 74 75 var netmask int 76 var family uint16 77 78 if len(clientIP) == 4 { 79 family = 1 80 netmask = 24 // 24 for IPV4, 96 for IPv6 81 } else { 82 family = 2 83 netmask = 96 84 } 85 86 b := make([]byte, 4) 87 binary.BigEndian.PutUint16(b[0:], family) 88 b[2] = byte(netmask) 89 b[3] = 0 90 switch family { 91 case 1: 92 ip := clientIP.To4().Mask(net.CIDRMask(netmask, net.IPv4len*8)) 93 needLength := (netmask + 8 - 1) / 8 // division rounding up 94 b = append(b, ip[:needLength]...) 95 case 2: 96 ip := clientIP.Mask(net.CIDRMask(netmask, net.IPv6len*8)) 97 needLength := (netmask + 8 - 1) / 8 // division rounding up 98 b = append(b, ip[:needLength]...) 99 } 100 101 const EDNS0SUBNET = 0x08 102 103 opt := new(dnsmessage.Resource) 104 common.Must(opt.Header.SetEDNS0(1350, 0xfe00, true)) 105 106 opt.Body = &dnsmessage.OPTResource{ 107 Options: []dnsmessage.Option{ 108 { 109 Code: EDNS0SUBNET, 110 Data: b, 111 }, 112 }, 113 } 114 115 return opt 116 } 117 118 func buildReqMsgs(domain string, option dns_feature.IPOption, reqIDGen func() uint16, reqOpts *dnsmessage.Resource) []*dnsRequest { 119 qA := dnsmessage.Question{ 120 Name: dnsmessage.MustNewName(domain), 121 Type: dnsmessage.TypeA, 122 Class: dnsmessage.ClassINET, 123 } 124 125 qAAAA := dnsmessage.Question{ 126 Name: dnsmessage.MustNewName(domain), 127 Type: dnsmessage.TypeAAAA, 128 Class: dnsmessage.ClassINET, 129 } 130 131 var reqs []*dnsRequest 132 now := time.Now() 133 134 if option.IPv4Enable { 135 msg := new(dnsmessage.Message) 136 msg.Header.ID = reqIDGen() 137 msg.Header.RecursionDesired = true 138 msg.Questions = []dnsmessage.Question{qA} 139 if reqOpts != nil { 140 msg.Additionals = append(msg.Additionals, *reqOpts) 141 } 142 reqs = append(reqs, &dnsRequest{ 143 reqType: dnsmessage.TypeA, 144 domain: domain, 145 start: now, 146 msg: msg, 147 }) 148 } 149 150 if option.IPv6Enable { 151 msg := new(dnsmessage.Message) 152 msg.Header.ID = reqIDGen() 153 msg.Header.RecursionDesired = true 154 msg.Questions = []dnsmessage.Question{qAAAA} 155 if reqOpts != nil { 156 msg.Additionals = append(msg.Additionals, *reqOpts) 157 } 158 reqs = append(reqs, &dnsRequest{ 159 reqType: dnsmessage.TypeAAAA, 160 domain: domain, 161 start: now, 162 msg: msg, 163 }) 164 } 165 166 return reqs 167 } 168 169 // parseResponse parses DNS answers from the returned payload 170 func parseResponse(payload []byte) (*IPRecord, error) { 171 var parser dnsmessage.Parser 172 h, err := parser.Start(payload) 173 if err != nil { 174 return nil, newError("failed to parse DNS response").Base(err).AtWarning() 175 } 176 if err := parser.SkipAllQuestions(); err != nil { 177 return nil, newError("failed to skip questions in DNS response").Base(err).AtWarning() 178 } 179 180 now := time.Now() 181 ipRecord := &IPRecord{ 182 ReqID: h.ID, 183 RCode: h.RCode, 184 Expire: now.Add(time.Second * 600), 185 } 186 187 L: 188 for { 189 ah, err := parser.AnswerHeader() 190 if err != nil { 191 if err != dnsmessage.ErrSectionDone { 192 newError("failed to parse answer section for domain: ", ah.Name.String()).Base(err).WriteToLog() 193 } 194 break 195 } 196 197 ttl := ah.TTL 198 if ttl == 0 { 199 ttl = 600 200 } 201 expire := now.Add(time.Duration(ttl) * time.Second) 202 if ipRecord.Expire.After(expire) { 203 ipRecord.Expire = expire 204 } 205 206 switch ah.Type { 207 case dnsmessage.TypeA: 208 ans, err := parser.AResource() 209 if err != nil { 210 newError("failed to parse A record for domain: ", ah.Name).Base(err).WriteToLog() 211 break L 212 } 213 ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.A[:])) 214 case dnsmessage.TypeAAAA: 215 ans, err := parser.AAAAResource() 216 if err != nil { 217 newError("failed to parse AAAA record for domain: ", ah.Name).Base(err).WriteToLog() 218 break L 219 } 220 ipRecord.IP = append(ipRecord.IP, net.IPAddress(ans.AAAA[:])) 221 default: 222 if err := parser.SkipAnswer(); err != nil { 223 newError("failed to skip answer").Base(err).WriteToLog() 224 break L 225 } 226 continue 227 } 228 } 229 230 return ipRecord, nil 231 } 232 233 // toDnsContext create a new background context with parent inbound, session and dns log 234 func toDnsContext(ctx context.Context, addr string) context.Context { 235 dnsCtx := core.ToBackgroundDetachedContext(ctx) 236 if inbound := session.InboundFromContext(ctx); inbound != nil { 237 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 238 } 239 dnsCtx = session.ContextWithContent(dnsCtx, session.ContentFromContext(ctx)) 240 dnsCtx = log.ContextWithAccessMessage(dnsCtx, &log.AccessMessage{ 241 From: "DNS", 242 To: addr, 243 Status: log.AccessAccepted, 244 Reason: "", 245 }) 246 return dnsCtx 247 }