github.com/v2fly/v2ray-core/v4@v4.45.2/app/dns/nameserver_udp.go (about) 1 //go:build !confonly 2 // +build !confonly 3 4 package dns 5 6 import ( 7 "context" 8 "strings" 9 "sync" 10 "sync/atomic" 11 "time" 12 13 "golang.org/x/net/dns/dnsmessage" 14 15 core "github.com/v2fly/v2ray-core/v4" 16 "github.com/v2fly/v2ray-core/v4/common" 17 "github.com/v2fly/v2ray-core/v4/common/net" 18 "github.com/v2fly/v2ray-core/v4/common/protocol/dns" 19 udp_proto "github.com/v2fly/v2ray-core/v4/common/protocol/udp" 20 "github.com/v2fly/v2ray-core/v4/common/session" 21 "github.com/v2fly/v2ray-core/v4/common/signal/pubsub" 22 "github.com/v2fly/v2ray-core/v4/common/task" 23 dns_feature "github.com/v2fly/v2ray-core/v4/features/dns" 24 "github.com/v2fly/v2ray-core/v4/features/routing" 25 "github.com/v2fly/v2ray-core/v4/transport/internet/udp" 26 ) 27 28 // ClassicNameServer implemented traditional UDP DNS. 29 type ClassicNameServer struct { 30 sync.RWMutex 31 name string 32 address *net.Destination 33 ips map[string]*record 34 requests map[uint16]*dnsRequest 35 pub *pubsub.Service 36 udpServer *udp.Dispatcher 37 cleanup *task.Periodic 38 reqID uint32 39 } 40 41 // NewClassicNameServer creates udp server object for remote resolving. 42 func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher) *ClassicNameServer { 43 // default to 53 if unspecific 44 if address.Port == 0 { 45 address.Port = net.Port(53) 46 } 47 48 s := &ClassicNameServer{ 49 address: &address, 50 ips: make(map[string]*record), 51 requests: make(map[uint16]*dnsRequest), 52 pub: pubsub.NewService(), 53 name: strings.ToUpper(address.String()), 54 } 55 s.cleanup = &task.Periodic{ 56 Interval: time.Minute, 57 Execute: s.Cleanup, 58 } 59 s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse) 60 newError("DNS: created UDP client initialized for ", address.NetAddr()).AtInfo().WriteToLog() 61 return s 62 } 63 64 // Name implements Server. 65 func (s *ClassicNameServer) Name() string { 66 return s.name 67 } 68 69 // Cleanup clears expired items from cache 70 func (s *ClassicNameServer) Cleanup() error { 71 now := time.Now() 72 s.Lock() 73 defer s.Unlock() 74 75 if len(s.ips) == 0 && len(s.requests) == 0 { 76 return newError(s.name, " nothing to do. stopping...") 77 } 78 79 for domain, record := range s.ips { 80 if record.A != nil && record.A.Expire.Before(now) { 81 record.A = nil 82 } 83 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 84 record.AAAA = nil 85 } 86 87 if record.A == nil && record.AAAA == nil { 88 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 89 delete(s.ips, domain) 90 } else { 91 s.ips[domain] = record 92 } 93 } 94 95 if len(s.ips) == 0 { 96 s.ips = make(map[string]*record) 97 } 98 99 for id, req := range s.requests { 100 if req.expire.Before(now) { 101 delete(s.requests, id) 102 } 103 } 104 105 if len(s.requests) == 0 { 106 s.requests = make(map[uint16]*dnsRequest) 107 } 108 109 return nil 110 } 111 112 // HandleResponse handles udp response packet from remote DNS server. 113 func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) { 114 ipRec, err := parseResponse(packet.Payload.Bytes()) 115 if err != nil { 116 newError(s.name, " fail to parse responded DNS udp").AtError().WriteToLog() 117 return 118 } 119 120 s.Lock() 121 id := ipRec.ReqID 122 req, ok := s.requests[id] 123 if ok { 124 // remove the pending request 125 delete(s.requests, id) 126 } 127 s.Unlock() 128 if !ok { 129 newError(s.name, " cannot find the pending request").AtError().WriteToLog() 130 return 131 } 132 133 var rec record 134 switch req.reqType { 135 case dnsmessage.TypeA: 136 rec.A = ipRec 137 case dnsmessage.TypeAAAA: 138 rec.AAAA = ipRec 139 } 140 141 elapsed := time.Since(req.start) 142 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 143 if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) { 144 s.updateIP(req.domain, &rec) 145 } 146 } 147 148 func (s *ClassicNameServer) updateIP(domain string, newRec *record) { 149 s.Lock() 150 151 rec, found := s.ips[domain] 152 if !found { 153 rec = &record{} 154 } 155 156 updated := false 157 if isNewer(rec.A, newRec.A) { 158 rec.A = newRec.A 159 updated = true 160 } 161 if isNewer(rec.AAAA, newRec.AAAA) { 162 rec.AAAA = newRec.AAAA 163 updated = true 164 } 165 166 if updated { 167 newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog() 168 s.ips[domain] = rec 169 } 170 if newRec.A != nil { 171 s.pub.Publish(domain+"4", nil) 172 } 173 if newRec.AAAA != nil { 174 s.pub.Publish(domain+"6", nil) 175 } 176 s.Unlock() 177 common.Must(s.cleanup.Start()) 178 } 179 180 func (s *ClassicNameServer) newReqID() uint16 { 181 return uint16(atomic.AddUint32(&s.reqID, 1)) 182 } 183 184 func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) { 185 s.Lock() 186 defer s.Unlock() 187 188 id := req.msg.ID 189 req.expire = time.Now().Add(time.Second * 8) 190 s.requests[id] = req 191 } 192 193 func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 194 newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 195 196 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 197 198 for _, req := range reqs { 199 s.addPendingRequest(req) 200 b, _ := dns.PackMessage(req.msg) 201 udpCtx := core.ToBackgroundDetachedContext(ctx) 202 if inbound := session.InboundFromContext(ctx); inbound != nil { 203 udpCtx = session.ContextWithInbound(udpCtx, inbound) 204 } 205 udpCtx = session.ContextWithContent(udpCtx, &session.Content{ 206 Protocol: "dns", 207 }) 208 s.udpServer.Dispatch(udpCtx, *s.address, b) 209 } 210 } 211 212 func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 213 s.RLock() 214 record, found := s.ips[domain] 215 s.RUnlock() 216 217 if !found { 218 return nil, errRecordNotFound 219 } 220 221 var err4 error 222 var err6 error 223 var ips []net.Address 224 var ip6 []net.Address 225 226 if option.IPv4Enable { 227 ips, err4 = record.A.getIPs() 228 } 229 230 if option.IPv6Enable { 231 ip6, err6 = record.AAAA.getIPs() 232 ips = append(ips, ip6...) 233 } 234 235 if len(ips) > 0 { 236 return toNetIP(ips) 237 } 238 239 if err4 != nil { 240 return nil, err4 241 } 242 243 if err6 != nil { 244 return nil, err6 245 } 246 247 return nil, dns_feature.ErrEmptyResponse 248 } 249 250 // QueryIP implements Server. 251 func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { 252 fqdn := Fqdn(domain) 253 254 if disableCache { 255 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 256 } else { 257 ips, err := s.findIPsForDomain(fqdn, option) 258 if err != errRecordNotFound { 259 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 260 return ips, err 261 } 262 } 263 264 // ipv4 and ipv6 belong to different subscription groups 265 var sub4, sub6 *pubsub.Subscriber 266 if option.IPv4Enable { 267 sub4 = s.pub.Subscribe(fqdn + "4") 268 defer sub4.Close() 269 } 270 if option.IPv6Enable { 271 sub6 = s.pub.Subscribe(fqdn + "6") 272 defer sub6.Close() 273 } 274 done := make(chan interface{}) 275 go func() { 276 if sub4 != nil { 277 select { 278 case <-sub4.Wait(): 279 case <-ctx.Done(): 280 } 281 } 282 if sub6 != nil { 283 select { 284 case <-sub6.Wait(): 285 case <-ctx.Done(): 286 } 287 } 288 close(done) 289 }() 290 s.sendQuery(ctx, fqdn, clientIP, option) 291 292 for { 293 ips, err := s.findIPsForDomain(fqdn, option) 294 if err != errRecordNotFound { 295 return ips, err 296 } 297 298 select { 299 case <-ctx.Done(): 300 return nil, ctx.Err() 301 case <-done: 302 } 303 } 304 }