github.com/v2fly/v2ray-core/v5@v5.16.2-0.20240507031116-8191faa6e095/app/dns/nameserver_udp.go (about) 1 //go:build !confonly 2 // +build !confonly 3 4 package dns 5 6 import ( 7 "context" 8 "strings" 9 "sync" 10 "sync/atomic" 11 "time" 12 13 "golang.org/x/net/dns/dnsmessage" 14 15 core "github.com/v2fly/v2ray-core/v5" 16 "github.com/v2fly/v2ray-core/v5/common" 17 "github.com/v2fly/v2ray-core/v5/common/net" 18 "github.com/v2fly/v2ray-core/v5/common/protocol/dns" 19 udp_proto "github.com/v2fly/v2ray-core/v5/common/protocol/udp" 20 "github.com/v2fly/v2ray-core/v5/common/session" 21 "github.com/v2fly/v2ray-core/v5/common/signal/pubsub" 22 "github.com/v2fly/v2ray-core/v5/common/task" 23 dns_feature "github.com/v2fly/v2ray-core/v5/features/dns" 24 "github.com/v2fly/v2ray-core/v5/features/routing" 25 "github.com/v2fly/v2ray-core/v5/transport/internet/udp" 26 ) 27 28 // ClassicNameServer implemented traditional UDP DNS. 29 type ClassicNameServer struct { 30 sync.RWMutex 31 name string 32 address net.Destination 33 ips map[string]record 34 requests map[uint16]dnsRequest 35 pub *pubsub.Service 36 udpServer udp.DispatcherI 37 cleanup *task.Periodic 38 reqID uint32 39 } 40 41 // NewClassicNameServer creates udp server object for remote resolving. 42 func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher) *ClassicNameServer { 43 // default to 53 if unspecific 44 if address.Port == 0 { 45 address.Port = net.Port(53) 46 } 47 48 s := &ClassicNameServer{ 49 address: address, 50 ips: make(map[string]record), 51 requests: make(map[uint16]dnsRequest), 52 pub: pubsub.NewService(), 53 name: strings.ToUpper(address.String()), 54 } 55 s.cleanup = &task.Periodic{ 56 Interval: time.Minute, 57 Execute: s.Cleanup, 58 } 59 s.udpServer = udp.NewSplitDispatcher(dispatcher, s.HandleResponse) 60 newError("DNS: created UDP client initialized for ", address.NetAddr()).AtInfo().WriteToLog() 61 return s 62 } 63 64 // Name implements Server. 
func (s *ClassicNameServer) Name() string {
	return s.name
}

// Cleanup is the body of the periodic cleanup task. It drops A/AAAA records
// whose Expire time has passed, removes domains left with no records, and
// discards pending requests past their expire time.
//
// When both the cache and the pending-request table are already empty it
// returns an error ("nothing to do. stopping..."). This signals the periodic
// task to stop; updateIP calls cleanup.Start() again when new records arrive.
func (s *ClassicNameServer) Cleanup() error {
	now := time.Now()
	s.Lock()
	defer s.Unlock()

	if len(s.ips) == 0 && len(s.requests) == 0 {
		return newError(s.name, " nothing to do. stopping...")
	}

	for domain, rec := range s.ips {
		if rec.A != nil && rec.A.Expire.Before(now) {
			rec.A = nil
		}
		if rec.AAAA != nil && rec.AAAA.Expire.Before(now) {
			rec.AAAA = nil
		}
		if rec.A != nil || rec.AAAA != nil {
			// At least one family is still valid: write back the trimmed entry.
			s.ips[domain] = rec
			continue
		}
		delete(s.ips, domain)
	}

	// Replace a drained map with a fresh one, presumably so the memory held by
	// the old map can be released.
	if len(s.ips) == 0 {
		s.ips = make(map[string]record)
	}

	for id, req := range s.requests {
		if !req.expire.Before(now) {
			continue
		}
		delete(s.requests, id)
	}

	if len(s.requests) == 0 {
		s.requests = make(map[uint16]dnsRequest)
	}

	return nil
}

// HandleResponse is the callback for UDP replies from the upstream server. It
// matches a reply to its pending request by message ID and caches the
// resulting A or AAAA record.
func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) {
	ipRec, err := parseResponse(packet.Payload.Bytes())
	if err != nil {
		// Attach the parse error so malformed upstream replies can be diagnosed.
		newError(s.name, " fail to parse responded DNS udp").Base(err).AtError().WriteToLog()
		return
	}

	// Claim the pending request under the lock so each reply is consumed at
	// most once.
	s.Lock()
	id := ipRec.ReqID
	req, ok := s.requests[id]
	if ok {
		delete(s.requests, id)
	}
	s.Unlock()
	if !ok {
		newError(s.name, " cannot find the pending request").AtError().WriteToLog()
		return
	}

	var rec record
	switch req.reqType {
	case dnsmessage.TypeA:
		rec.A = ipRec
	case dnsmessage.TypeAAAA:
		rec.AAAA = ipRec
	}

	elapsed := time.Since(req.start)
	newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog()
	if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) {
		s.updateIP(req.domain, rec)
	}
}

// updateIP merges newRec into the cache entry for domain. Each family (A,
// AAAA) is replaced only when isNewer reports the incoming record should
// supersede the cached one. It then publishes on "<domain>4" and/or
// "<domain>6" to wake QueryIP waiters, and (re)starts the periodic cleanup.
func (s *ClassicNameServer) updateIP(domain string, newRec record) {
	s.Lock()

	newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog()
	rec := s.ips[domain]

	updated := false
	if isNewer(rec.A, newRec.A) {
		rec.A = newRec.A
		updated = true
	}
	if isNewer(rec.AAAA, newRec.AAAA) {
		rec.AAAA = newRec.AAAA
		updated = true
	}

	if updated {
		s.ips[domain] = rec
	}
	// Notify waiters for every family that was answered, even if the cache
	// kept an existing record.
	if newRec.A != nil {
		s.pub.Publish(domain+"4", nil)
	}
	if newRec.AAAA != nil {
		s.pub.Publish(domain+"6", nil)
	}
	s.Unlock()
	common.Must(s.cleanup.Start())
}

// newReqID returns the next DNS message ID. The shared 32-bit counter is
// advanced atomically and truncated to 16 bits, so IDs wrap after 65535.
func (s *ClassicNameServer) newReqID() uint16 {
	return uint16(atomic.AddUint32(&s.reqID, 1))
}

// addPendingRequest records req as awaiting a reply, keyed by its message ID.
// It is given an 8-second expiry, after which Cleanup may discard it.
func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) {
	s.Lock()
	defer s.Unlock()

	id := req.msg.ID
	req.expire = time.Now().Add(time.Second * 8)
	s.requests[id] = *req
}

// sendQuery builds the DNS questions for domain (per option, with EDNS0
// options derived from clientIP), registers each as a pending request, and
// dispatches it to the upstream server.
//
// A question that fails to pack (for example, a label longer than 63 bytes) is
// logged and skipped instead of dispatching a nil payload.
func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) {
	newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx))

	reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP))

	for _, req := range reqs {
		// Pack before registering so an unsendable question never leaves a
		// dangling pending request behind.
		b, err := dns.PackMessage(req.msg)
		if err != nil {
			newError(s.name, " failed to pack DNS query for: ", domain).Base(err).AtError().WriteToLog(session.ExportIDToError(ctx))
			continue
		}
		s.addPendingRequest(req)

		// Detach from the caller's cancellation but keep the inbound
		// metadata, and tag the traffic as DNS for routing.
		udpCtx := core.ToBackgroundDetachedContext(ctx)
		if inbound := session.InboundFromContext(ctx); inbound != nil {
			udpCtx = session.ContextWithInbound(udpCtx, inbound)
		}
		udpCtx = session.ContextWithContent(udpCtx, &session.Content{
			Protocol: "dns",
		})
		s.udpServer.Dispatch(udpCtx, s.address, b)
	}
}

// findIPsForDomain looks domain up in the cache.
//
// It returns errRecordNotFound when the domain has no cache entry. Otherwise
// it gathers the addresses of each enabled family. If none are available, it
// returns the last error reported by getIPs, or dns_feature.ErrEmptyResponse
// when there was no such error.
func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) {
	s.RLock()
	record, found := s.ips[domain]
	s.RUnlock()

	if !found {
		return nil, errRecordNotFound
	}

	var ips []net.Address
	var lastErr error
	if option.IPv4Enable {
		a, err := record.A.getIPs()
		if err != nil {
			lastErr = err
		}
		ips = append(ips, a...)
	}

	if option.IPv6Enable {
		aaaa, err := record.AAAA.getIPs()
		if err != nil {
			lastErr = err
		}
		ips = append(ips, aaaa...)
	}

	if len(ips) > 0 {
		return toNetIP(ips)
	}

	if lastErr != nil {
		return nil, lastErr
	}

	return nil, dns_feature.ErrEmptyResponse
}

// QueryIP implements Server.
//
// Unless disableCache is set, any cached answer or cached error for the domain
// is returned immediately. Otherwise queries are sent upstream, and the call
// waits until the replies for every enabled family have been published or ctx
// is done.
func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) {
	fqdn := Fqdn(domain)

	if disableCache {
		newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog()
	} else {
		ips, err := s.findIPsForDomain(fqdn, option)
		if err != errRecordNotFound {
			newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog()
			return ips, err
		}
	}

	// ipv4 and ipv6 belong to different subscription groups
	var sub4, sub6 *pubsub.Subscriber
	if option.IPv4Enable {
		sub4 = s.pub.Subscribe(fqdn + "4")
		defer sub4.Close()
	}
	if option.IPv6Enable {
		sub6 = s.pub.Subscribe(fqdn + "6")
		defer sub6.Close()
	}

	// done is closed once every enabled subscription has fired (or ctx is
	// done). stop is closed when QueryIP returns. If we return before both
	// replies arrive (e.g. another caller filled the cache), the subscribers
	// are closed by the defers above, and without stop the waiter could stay
	// blocked until ctx is cancelled, which may be never.
	done := make(chan struct{})
	stop := make(chan struct{})
	defer close(stop)
	go func() {
		defer close(done)
		for _, sub := range []*pubsub.Subscriber{sub4, sub6} {
			if sub == nil {
				continue
			}
			select {
			case <-sub.Wait():
			case <-ctx.Done():
			case <-stop:
				return
			}
		}
	}()
	s.sendQuery(ctx, fqdn, clientIP, option)

	for {
		ips, err := s.findIPsForDomain(fqdn, option)
		if err != errRecordNotFound {
			return ips, err
		}

		select {
		case <-ctx.Done():
			return nil, ctx.Err()
		case <-done:
		}
	}
}