// github.com/xraypb/xray-core@v1.6.6/app/dns/nameserver_tcp.go

package dns

import (
	"bytes"
	"context"
	"encoding/binary"
	"net/url"
	"sync"
	"sync/atomic"
	"time"

	"github.com/xraypb/xray-core/common"
	"github.com/xraypb/xray-core/common/buf"
	"github.com/xraypb/xray-core/common/log"
	"github.com/xraypb/xray-core/common/net"
	"github.com/xraypb/xray-core/common/net/cnc"
	"github.com/xraypb/xray-core/common/protocol/dns"
	"github.com/xraypb/xray-core/common/session"
	"github.com/xraypb/xray-core/common/signal/pubsub"
	"github.com/xraypb/xray-core/common/task"
	dns_feature "github.com/xraypb/xray-core/features/dns"
	"github.com/xraypb/xray-core/features/routing"
	"github.com/xraypb/xray-core/transport/internet"
	"golang.org/x/net/dns/dnsmessage"
)

// TCPNameServer implements DNS over TCP (RFC 7766).
//
// Every query is sent over a new connection obtained from dial. Answers are
// cached per domain in ips and announced to waiting callers through pub.
type TCPNameServer struct {
	// Guards ips and the records it points to.
	sync.RWMutex
	// name identifies the server in logs: the constructor's prefix ("TCP" or
	// "TCPL"), then "//", then the destination's network address.
	name string
	// destination is the TCP address of the upstream DNS server.
	destination *net.Destination
	// ips caches A/AAAA answers keyed by the queried domain (QueryIP looks
	// entries up by Fqdn(domain)).
	ips map[string]*record
	// pub wakes QueryIP callers; topics are the domain followed by "4" or "6".
	pub *pubsub.Service
	// cleanup periodically runs Cleanup to evict expired records.
	cleanup *task.Periodic
	// reqID is incremented atomically; its low 16 bits serve as DNS message IDs.
	reqID uint32
	// dial opens a new connection to destination; set by the constructors.
	dial func(context.Context) (net.Conn, error)
}

// NewTCPNameServer creates a DNS over TCP client for remote resolving: its
// connections to the server are opened through the routing dispatcher.
40 func NewTCPNameServer(url *url.URL, dispatcher routing.Dispatcher) (*TCPNameServer, error) { 41 s, err := baseTCPNameServer(url, "TCP") 42 if err != nil { 43 return nil, err 44 } 45 46 s.dial = func(ctx context.Context) (net.Conn, error) { 47 link, err := dispatcher.Dispatch(toDnsContext(ctx, s.destination.String()), *s.destination) 48 if err != nil { 49 return nil, err 50 } 51 52 return cnc.NewConnection( 53 cnc.ConnectionInputMulti(link.Writer), 54 cnc.ConnectionOutputMulti(link.Reader), 55 ), nil 56 } 57 58 return s, nil 59 } 60 61 // NewTCPLocalNameServer creates DNS over TCP client object for local resolving 62 func NewTCPLocalNameServer(url *url.URL) (*TCPNameServer, error) { 63 s, err := baseTCPNameServer(url, "TCPL") 64 if err != nil { 65 return nil, err 66 } 67 68 s.dial = func(ctx context.Context) (net.Conn, error) { 69 return internet.DialSystem(ctx, *s.destination, nil) 70 } 71 72 return s, nil 73 } 74 75 func baseTCPNameServer(url *url.URL, prefix string) (*TCPNameServer, error) { 76 var err error 77 port := net.Port(53) 78 if url.Port() != "" { 79 port, err = net.PortFromString(url.Port()) 80 if err != nil { 81 return nil, err 82 } 83 } 84 dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port) 85 86 s := &TCPNameServer{ 87 destination: &dest, 88 ips: make(map[string]*record), 89 pub: pubsub.NewService(), 90 name: prefix + "//" + dest.NetAddr(), 91 } 92 s.cleanup = &task.Periodic{ 93 Interval: time.Minute, 94 Execute: s.Cleanup, 95 } 96 97 return s, nil 98 } 99 100 // Name implements Server. 101 func (s *TCPNameServer) Name() string { 102 return s.name 103 } 104 105 // Cleanup clears expired items from cache 106 func (s *TCPNameServer) Cleanup() error { 107 now := time.Now() 108 s.Lock() 109 defer s.Unlock() 110 111 if len(s.ips) == 0 { 112 return newError("nothing to do. 
stopping...") 113 } 114 115 for domain, record := range s.ips { 116 if record.A != nil && record.A.Expire.Before(now) { 117 record.A = nil 118 } 119 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 120 record.AAAA = nil 121 } 122 123 if record.A == nil && record.AAAA == nil { 124 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 125 delete(s.ips, domain) 126 } else { 127 s.ips[domain] = record 128 } 129 } 130 131 if len(s.ips) == 0 { 132 s.ips = make(map[string]*record) 133 } 134 135 return nil 136 } 137 138 func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 139 elapsed := time.Since(req.start) 140 141 s.Lock() 142 rec, found := s.ips[req.domain] 143 if !found { 144 rec = &record{} 145 } 146 updated := false 147 148 switch req.reqType { 149 case dnsmessage.TypeA: 150 if isNewer(rec.A, ipRec) { 151 rec.A = ipRec 152 updated = true 153 } 154 case dnsmessage.TypeAAAA: 155 addr := make([]net.Address, 0) 156 for _, ip := range ipRec.IP { 157 if len(ip.IP()) == net.IPv6len { 158 addr = append(addr, ip) 159 } 160 } 161 ipRec.IP = addr 162 if isNewer(rec.AAAA, ipRec) { 163 rec.AAAA = ipRec 164 updated = true 165 } 166 } 167 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 168 169 if updated { 170 s.ips[req.domain] = rec 171 } 172 switch req.reqType { 173 case dnsmessage.TypeA: 174 s.pub.Publish(req.domain+"4", nil) 175 case dnsmessage.TypeAAAA: 176 s.pub.Publish(req.domain+"6", nil) 177 } 178 s.Unlock() 179 common.Must(s.cleanup.Start()) 180 } 181 182 func (s *TCPNameServer) newReqID() uint16 { 183 return uint16(atomic.AddUint32(&s.reqID, 1)) 184 } 185 186 func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 187 newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 188 189 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 190 191 var 
deadline time.Time 192 if d, ok := ctx.Deadline(); ok { 193 deadline = d 194 } else { 195 deadline = time.Now().Add(time.Second * 5) 196 } 197 198 for _, req := range reqs { 199 go func(r *dnsRequest) { 200 dnsCtx := ctx 201 202 if inbound := session.InboundFromContext(ctx); inbound != nil { 203 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 204 } 205 206 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 207 Protocol: "dns", 208 SkipDNSResolve: true, 209 }) 210 211 var cancel context.CancelFunc 212 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 213 defer cancel() 214 215 b, err := dns.PackMessage(r.msg) 216 if err != nil { 217 newError("failed to pack dns query").Base(err).AtError().WriteToLog() 218 return 219 } 220 221 conn, err := s.dial(dnsCtx) 222 if err != nil { 223 newError("failed to dial namesever").Base(err).AtError().WriteToLog() 224 return 225 } 226 defer conn.Close() 227 dnsReqBuf := buf.New() 228 binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len())) 229 dnsReqBuf.Write(b.Bytes()) 230 b.Release() 231 232 _, err = conn.Write(dnsReqBuf.Bytes()) 233 if err != nil { 234 newError("failed to send query").Base(err).AtError().WriteToLog() 235 return 236 } 237 dnsReqBuf.Release() 238 239 respBuf := buf.New() 240 defer respBuf.Release() 241 n, err := respBuf.ReadFullFrom(conn, 2) 242 if err != nil && n == 0 { 243 newError("failed to read response length").Base(err).AtError().WriteToLog() 244 return 245 } 246 var length int16 247 err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length) 248 if err != nil { 249 newError("failed to parse response length").Base(err).AtError().WriteToLog() 250 return 251 } 252 respBuf.Clear() 253 n, err = respBuf.ReadFullFrom(conn, int32(length)) 254 if err != nil && n == 0 { 255 newError("failed to read response length").Base(err).AtError().WriteToLog() 256 return 257 } 258 259 rec, err := parseResponse(respBuf.Bytes()) 260 if err != nil { 261 newError("failed to parse DNS over TCP 
response").Base(err).AtError().WriteToLog() 262 return 263 } 264 265 s.updateIP(r, rec) 266 }(req) 267 } 268 } 269 270 func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 271 s.RLock() 272 record, found := s.ips[domain] 273 s.RUnlock() 274 275 if !found { 276 return nil, errRecordNotFound 277 } 278 279 var err4 error 280 var err6 error 281 var ips []net.Address 282 var ip6 []net.Address 283 284 if option.IPv4Enable { 285 ips, err4 = record.A.getIPs() 286 } 287 288 if option.IPv6Enable { 289 ip6, err6 = record.AAAA.getIPs() 290 ips = append(ips, ip6...) 291 } 292 293 if len(ips) > 0 { 294 return toNetIP(ips) 295 } 296 297 if err4 != nil { 298 return nil, err4 299 } 300 301 if err6 != nil { 302 return nil, err6 303 } 304 305 return nil, dns_feature.ErrEmptyResponse 306 } 307 308 // QueryIP implements Server. 309 func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { 310 fqdn := Fqdn(domain) 311 312 if disableCache { 313 newError("DNS cache is disabled. 
Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 314 } else { 315 ips, err := s.findIPsForDomain(fqdn, option) 316 if err != errRecordNotFound { 317 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 318 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) 319 return ips, err 320 } 321 } 322 323 // ipv4 and ipv6 belong to different subscription groups 324 var sub4, sub6 *pubsub.Subscriber 325 if option.IPv4Enable { 326 sub4 = s.pub.Subscribe(fqdn + "4") 327 defer sub4.Close() 328 } 329 if option.IPv6Enable { 330 sub6 = s.pub.Subscribe(fqdn + "6") 331 defer sub6.Close() 332 } 333 done := make(chan interface{}) 334 go func() { 335 if sub4 != nil { 336 select { 337 case <-sub4.Wait(): 338 case <-ctx.Done(): 339 } 340 } 341 if sub6 != nil { 342 select { 343 case <-sub6.Wait(): 344 case <-ctx.Done(): 345 } 346 } 347 close(done) 348 }() 349 s.sendQuery(ctx, fqdn, clientIP, option) 350 start := time.Now() 351 352 for { 353 ips, err := s.findIPsForDomain(fqdn, option) 354 if err != errRecordNotFound { 355 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) 356 return ips, err 357 } 358 359 select { 360 case <-ctx.Done(): 361 return nil, ctx.Err() 362 case <-done: 363 } 364 } 365 }