github.com/eagleql/xray-core@v1.4.4/app/dns/udpns.go (about) 1 package dns 2 3 import ( 4 "context" 5 "github.com/eagleql/xray-core/transport/internet" 6 "strings" 7 "sync" 8 "sync/atomic" 9 "time" 10 11 "github.com/eagleql/xray-core/common" 12 "github.com/eagleql/xray-core/common/log" 13 "github.com/eagleql/xray-core/common/net" 14 "github.com/eagleql/xray-core/common/protocol/dns" 15 udp_proto "github.com/eagleql/xray-core/common/protocol/udp" 16 "github.com/eagleql/xray-core/common/session" 17 "github.com/eagleql/xray-core/common/signal/pubsub" 18 "github.com/eagleql/xray-core/common/task" 19 dns_feature "github.com/eagleql/xray-core/features/dns" 20 "github.com/eagleql/xray-core/features/routing" 21 "github.com/eagleql/xray-core/transport/internet/udp" 22 "golang.org/x/net/dns/dnsmessage" 23 ) 24 25 type ClassicNameServer struct { 26 sync.RWMutex 27 name string 28 address net.Destination 29 ips map[string]record 30 requests map[uint16]dnsRequest 31 pub *pubsub.Service 32 udpServer *udp.Dispatcher 33 cleanup *task.Periodic 34 reqID uint32 35 clientIP net.IP 36 } 37 38 func NewClassicNameServer(address net.Destination, dispatcher routing.Dispatcher, clientIP net.IP) *ClassicNameServer { 39 // default to 53 if unspecific 40 if address.Port == 0 { 41 address.Port = net.Port(53) 42 } 43 44 s := &ClassicNameServer{ 45 address: address, 46 ips: make(map[string]record), 47 requests: make(map[uint16]dnsRequest), 48 clientIP: clientIP, 49 pub: pubsub.NewService(), 50 name: strings.ToUpper(address.String()), 51 } 52 s.cleanup = &task.Periodic{ 53 Interval: time.Minute, 54 Execute: s.Cleanup, 55 } 56 s.udpServer = udp.NewDispatcher(dispatcher, s.HandleResponse) 57 newError("DNS: created UDP client initialized for ", address.NetAddr()).AtInfo().WriteToLog() 58 return s 59 } 60 61 func (s *ClassicNameServer) Name() string { 62 return s.name 63 } 64 65 func (s *ClassicNameServer) Cleanup() error { 66 now := time.Now() 67 s.Lock() 68 defer s.Unlock() 69 70 if len(s.ips) == 0 && len(s.requests) == 0 { 71 return newError(s.name, " nothing to do. stopping...") 72 } 73 74 for domain, record := range s.ips { 75 if record.A != nil && record.A.Expire.Before(now) { 76 record.A = nil 77 } 78 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 79 record.AAAA = nil 80 } 81 82 if record.A == nil && record.AAAA == nil { 83 delete(s.ips, domain) 84 } else { 85 s.ips[domain] = record 86 } 87 } 88 89 if len(s.ips) == 0 { 90 s.ips = make(map[string]record) 91 } 92 93 for id, req := range s.requests { 94 if req.expire.Before(now) { 95 delete(s.requests, id) 96 } 97 } 98 99 if len(s.requests) == 0 { 100 s.requests = make(map[uint16]dnsRequest) 101 } 102 103 return nil 104 } 105 106 func (s *ClassicNameServer) HandleResponse(ctx context.Context, packet *udp_proto.Packet) { 107 ipRec, err := parseResponse(packet.Payload.Bytes()) 108 if err != nil { 109 newError(s.name, " fail to parse responded DNS udp").AtError().WriteToLog() 110 return 111 } 112 113 s.Lock() 114 id := ipRec.ReqID 115 req, ok := s.requests[id] 116 if ok { 117 // remove the pending request 118 delete(s.requests, id) 119 } 120 s.Unlock() 121 if !ok { 122 newError(s.name, " cannot find the pending request").AtError().WriteToLog() 123 return 124 } 125 126 var rec record 127 switch req.reqType { 128 case dnsmessage.TypeA: 129 rec.A = ipRec 130 case dnsmessage.TypeAAAA: 131 rec.AAAA = ipRec 132 } 133 134 elapsed := time.Since(req.start) 135 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 136 if len(req.domain) > 0 && (rec.A != nil || rec.AAAA != nil) { 137 s.updateIP(req.domain, rec) 138 } 139 } 140 141 func (s *ClassicNameServer) updateIP(domain string, newRec record) { 142 s.Lock() 143 144 newError(s.name, " updating IP records for domain:", domain).AtDebug().WriteToLog() 145 rec := s.ips[domain] 146 147 updated := false 148 if isNewer(rec.A, newRec.A) { 149 rec.A = newRec.A 150 updated = true 151 } 152 if isNewer(rec.AAAA, newRec.AAAA) { 153 rec.AAAA = newRec.AAAA 154 updated = true 155 } 156 157 if updated { 158 s.ips[domain] = rec 159 } 160 if newRec.A != nil { 161 s.pub.Publish(domain+"4", nil) 162 } 163 if newRec.AAAA != nil { 164 s.pub.Publish(domain+"6", nil) 165 } 166 s.Unlock() 167 common.Must(s.cleanup.Start()) 168 } 169 170 func (s *ClassicNameServer) newReqID() uint16 { 171 return uint16(atomic.AddUint32(&s.reqID, 1)) 172 } 173 174 func (s *ClassicNameServer) addPendingRequest(req *dnsRequest) { 175 s.Lock() 176 defer s.Unlock() 177 178 id := req.msg.ID 179 req.expire = time.Now().Add(time.Second * 8) 180 s.requests[id] = *req 181 } 182 183 func (s *ClassicNameServer) sendQuery(ctx context.Context, domain string, option dns_feature.IPOption) { 184 newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 185 186 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(s.clientIP)) 187 188 for _, req := range reqs { 189 s.addPendingRequest(req) 190 b, _ := dns.PackMessage(req.msg) 191 udpCtx := context.Background() 192 if inbound := session.InboundFromContext(ctx); inbound != nil { 193 udpCtx = session.ContextWithInbound(udpCtx, inbound) 194 } 195 udpCtx = internet.ContextWithLookupDomain(udpCtx, internet.LookupDomainFromContext(ctx)) 196 udpCtx = session.ContextWithContent(udpCtx, &session.Content{ 197 Protocol: "dns", 198 }) 199 udpCtx = log.ContextWithAccessMessage(udpCtx, &log.AccessMessage{ 200 From: "DNS", 201 To: s.address, 202 Status: log.AccessAccepted, 203 Reason: "", 204 }) 205 s.udpServer.Dispatch(udpCtx, s.address, b) 206 } 207 } 208 209 func (s *ClassicNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 210 s.RLock() 211 record, found := s.ips[domain] 212 s.RUnlock() 213 214 if !found { 215 return nil, errRecordNotFound 216 } 217 218 var ips []net.Address 219 var lastErr error 220 if option.IPv4Enable { 221 a, err := record.A.getIPs() 222 if err != nil { 223 lastErr = err 224 } 225 ips = append(ips, a...) 226 } 227 228 if option.IPv6Enable { 229 aaaa, err := record.AAAA.getIPs() 230 if err != nil { 231 lastErr = err 232 } 233 ips = append(ips, aaaa...) 234 } 235 236 if len(ips) > 0 { 237 return toNetIP(ips), nil 238 } 239 240 if lastErr != nil { 241 return nil, lastErr 242 } 243 244 return nil, dns_feature.ErrEmptyResponse 245 } 246 247 // QueryIP implements Server. 248 func (s *ClassicNameServer) QueryIP(ctx context.Context, domain string, option dns_feature.IPOption) ([]net.IP, error) { 249 fqdn := Fqdn(domain) 250 251 ips, err := s.findIPsForDomain(fqdn, option) 252 if err != errRecordNotFound { 253 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 254 log.Record(&log.DNSLog{s.name, domain, ips, log.DNSCacheHit, 0, err}) 255 return ips, err 256 } 257 258 // ipv4 and ipv6 belong to different subscription groups 259 var sub4, sub6 *pubsub.Subscriber 260 if option.IPv4Enable { 261 sub4 = s.pub.Subscribe(fqdn + "4") 262 defer sub4.Close() 263 } 264 if option.IPv6Enable { 265 sub6 = s.pub.Subscribe(fqdn + "6") 266 defer sub6.Close() 267 } 268 done := make(chan interface{}) 269 go func() { 270 if sub4 != nil { 271 select { 272 case <-sub4.Wait(): 273 case <-ctx.Done(): 274 } 275 } 276 if sub6 != nil { 277 select { 278 case <-sub6.Wait(): 279 case <-ctx.Done(): 280 } 281 } 282 close(done) 283 }() 284 s.sendQuery(ctx, fqdn, option) 285 start := time.Now() 286 287 for { 288 ips, err := s.findIPsForDomain(fqdn, option) 289 if err != errRecordNotFound { 290 log.Record(&log.DNSLog{s.name, domain, ips, log.DNSQueried, time.Since(start), err}) 291 return ips, err 292 } 293 294 select { 295 case <-ctx.Done(): 296 return nil, ctx.Err() 297 case <-done: 298 } 299 } 300 }