github.com/moqsien/xraycore@v1.8.5/app/dns/nameserver_udp.go (about) 1 package dns 2 3 import ( 4 "context" 5 "strings" 6 "sync" 7 "sync/atomic" 8 "time" 9 10 "github.com/moqsien/xraycore/common" 11 "github.com/moqsien/xraycore/common/log" 12 "github.com/moqsien/xraycore/common/net" 13 "github.com/moqsien/xraycore/common/protocol/dns" 14 udp_proto "github.com/moqsien/xraycore/common/protocol/udp" 15 "github.com/moqsien/xraycore/common/session" 16 "github.com/moqsien/xraycore/common/signal/pubsub" 17 "github.com/moqsien/xraycore/common/task" 18 dns_feature "github.com/moqsien/xraycore/features/dns" 19 "github.com/moqsien/xraycore/features/routing" 20 "github.com/moqsien/xraycore/transport/internet/udp" 21 "golang.org/x/net/dns/dnsmessage" 22 ) 23 24 // ClassicNameServer implemented traditional UDP DNS. 25 type ClassicNameServer struct { 26 sync.RWMutex 27 name string 28 address *net.Destination 29 ips map[string]*record 30 requests map[uint16]*dnsRequest 31 pub *pubsub.Service 32 udpServer *udp.Dispatcher 33 cleanup *task.Periodic 34 reqID uint32 35 } 36 37 // NewClassicNameServer creates udp server object for remote resolving. 38 func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher) *ClassicNameServer { 39 // default to 53 if unspecific 40 if address.Port == 0 { 41 address.Port = net.Port(53) 42 } 43 44 s := &ClassicNameServer{ 45 address: &address, 46 ips: make(map[string]*record), 47 requests: make(map[uint16]*dnsRequest), 48 pub: pubsub.NewService(), 49 name: strings.ToUpper(address.String()), 50 } 51 s.cleanup = &task.Periodic{ 52 Interval: time.Minute, 53 Execute: s.Cleanup, 54 } 55 s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse) 56 newError("DNS: created UDP client initialized for ", address.NetAddr()).AtInfo().WriteToLog() 57 return s 58 } 59 60 // Name implements Server. 61 func (s *ClassicNameServer) Name() string { 62 return s.name 63 } 64 65 // Cleanup clears expired items from cache 66 func (s *ClassicNameServer) Cleanup() error { 67 now := time.Now() 68 s.Lock() 69 defer s.Unlock() 70 71 if len(s.ips) == 0 && len(s.requests) == 0 { 72 return newError(s.name, " nothing to do. stopping...") 73 } 74 75 for domain, record := range s.ips { 76 if record.A != nil && record.A.Expire.Before(now) { 77 record.A = nil 78 } 79 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 80 record.AAAA = nil 81 } 82 83 if record.A == nil && record.AAAA == nil { 84 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 85 delete(s.ips, domain) 86 } else { 87 s.ips[domain] = record 88 } 89 } 90 91 if len(s.ips) == 0 { 92 s.ips = make(map[string]*record) 93 } 94 95 for id, req := range s.requests { 96 if req.expire.Before(now) { 97 delete(s.requests, id) 98 } 99 } 100 101 if len(s.requests) == 0 { 102 s.requests = make(map[uint16]*dnsRequest) 103 } 104 105 return nil 106 } 107 108 // HandleResponse handles udp response packet from remote DNS server. 109 func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) { 110 ipRec, err := parseResponse(packet.Payload.Bytes()) 111 if err != nil { 112 newError(s.name, " fail to parse responded DNS udp").AtError().WriteToLog() 113 return 114 } 115 116 s.Lock() 117 id := ipRec.ReqID 118 req, ok := s.requests[id] 119 if ok { 120 // remove the pending request 121 delete(s.requests, id) 122 } 123 s.Unlock() 124 if !ok { 125 newError(s.name, " cannot find the pending request").AtError().WriteToLog() 126 return 127 } 128 129 var rec record 130 switch req.reqType { 131 case dnsmessage.TypeA: 132 rec.A = ipRec 133 case dnsmessage.TypeAAAA: 134 rec.AAAA = ipRec 135 } 136 137 elapsed := time.Since(req.start) 138 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 139 if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) { 140 s.updateIP(req.domain, &rec) 141 } 142 } 143 144 func (s *ClassicNameServer) updateIP(domain string, newRec *record) { 145 s.Lock() 146 147 rec, found := s.ips[domain] 148 if !found { 149 rec = &record{} 150 } 151 152 updated := false 153 if isNewer(rec.A, newRec.A) { 154 rec.A = newRec.A 155 updated = true 156 } 157 if isNewer(rec.AAAA, newRec.AAAA) { 158 rec.AAAA = newRec.AAAA 159 updated = true 160 } 161 162 if updated { 163 newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog() 164 s.ips[domain] = rec 165 } 166 if newRec.A != nil { 167 s.pub.Publish(domain+"4", nil) 168 } 169 if newRec.AAAA != nil { 170 s.pub.Publish(domain+"6", nil) 171 } 172 s.Unlock() 173 common.Must(s.cleanup.Start()) 174 } 175 176 func (s *ClassicNameServer) newReqID() uint16 { 177 return uint16(atomic.AddUint32(&s.reqID, 1)) 178 } 179 180 func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) { 181 s.Lock() 182 defer s.Unlock() 183 184 id := req.msg.ID 185 req.expire = time.Now().Add(time.Second * 8) 186 s.requests[id] = req 187 } 188 189 func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 190 newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 191 192 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 193 194 for _, req := range reqs { 195 s.addPendingRequest(req) 196 b, _ := dns.PackMessage(req.msg) 197 s.udpServer.Dispatch(toDnsContext(ctx, s.address.String()), *s.address, b) 198 } 199 } 200 201 func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 202 s.RLock() 203 record, found := s.ips[domain] 204 s.RUnlock() 205 206 if !found { 207 return nil, errRecordNotFound 208 } 209 210 var err4 error 211 var err6 error 212 var ips []net.Address 213 var ip6 []net.Address 214 215 if option.IPv4Enable { 216 ips, err4 = record.A.getIPs() 217 } 218 219 if option.IPv6Enable { 220 ip6, err6 = record.AAAA.getIPs() 221 ips = append(ips, ip6...) 222 } 223 224 if len(ips) > 0 { 225 return toNetIP(ips) 226 } 227 228 if err4 != nil { 229 return nil, err4 230 } 231 232 if err6 != nil { 233 return nil, err6 234 } 235 236 return nil, dns_feature.ErrEmptyResponse 237 } 238 239 // QueryIP implements Server. 240 func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { 241 fqdn := Fqdn(domain) 242 243 if disableCache { 244 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 245 } else { 246 ips, err := s.findIPsForDomain(fqdn, option) 247 if err != errRecordNotFound { 248 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 249 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) 250 return ips, err 251 } 252 } 253 254 // ipv4 and ipv6 belong to different subscription groups 255 var sub4, sub6 *pubsub.Subscriber 256 if option.IPv4Enable { 257 sub4 = s.pub.Subscribe(fqdn + "4") 258 defer sub4.Close() 259 } 260 if option.IPv6Enable { 261 sub6 = s.pub.Subscribe(fqdn + "6") 262 defer sub6.Close() 263 } 264 done := make(chan interface{}) 265 go func() { 266 if sub4 != nil { 267 select { 268 case <-sub4.Wait(): 269 case <-ctx.Done(): 270 } 271 } 272 if sub6 != nil { 273 select { 274 case <-sub6.Wait(): 275 case <-ctx.Done(): 276 } 277 } 278 close(done) 279 }() 280 s.sendQuery(ctx, fqdn, clientIP, option) 281 start := time.Now() 282 283 for { 284 ips, err := s.findIPsForDomain(fqdn, option) 285 if err != errRecordNotFound { 286 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) 287 return ips, err 288 } 289 290 select { 291 case <-ctx.Done(): 292 return nil, ctx.Err() 293 case <-done: 294 } 295 } 296 }