github.com/v2fly/v2ray-core/v4@v4.45.2/app/dns/nameserver_tcp.go (about) 1 //go:build !confonly 2 // +build !confonly 3 4 package dns 5 6 import ( 7 "bytes" 8 "context" 9 "encoding/binary" 10 "net/url" 11 "sync" 12 "sync/atomic" 13 "time" 14 15 "golang.org/x/net/dns/dnsmessage" 16 17 "github.com/v2fly/v2ray-core/v4/common" 18 "github.com/v2fly/v2ray-core/v4/common/buf" 19 "github.com/v2fly/v2ray-core/v4/common/net" 20 "github.com/v2fly/v2ray-core/v4/common/protocol/dns" 21 "github.com/v2fly/v2ray-core/v4/common/session" 22 "github.com/v2fly/v2ray-core/v4/common/signal/pubsub" 23 "github.com/v2fly/v2ray-core/v4/common/task" 24 dns_feature "github.com/v2fly/v2ray-core/v4/features/dns" 25 "github.com/v2fly/v2ray-core/v4/features/routing" 26 "github.com/v2fly/v2ray-core/v4/transport/internet" 27 ) 28 29 // TCPNameServer implemented DNS over TCP (RFC7766). 30 type TCPNameServer struct { 31 sync.RWMutex 32 name string 33 destination *net.Destination 34 ips map[string]*record 35 pub *pubsub.Service 36 cleanup *task.Periodic 37 reqID uint32 38 dial func(context.Context) (net.Conn, error) 39 } 40 41 // NewTCPNameServer creates DNS over TCP server object for remote resolving. 42 func NewTCPNameServer(url *url.URL, dispatcher routing.Dispatcher) (*TCPNameServer, error) { 43 s, err := baseTCPNameServer(url, "TCP") 44 if err != nil { 45 return nil, err 46 } 47 48 s.dial = func(ctx context.Context) (net.Conn, error) { 49 link, err := dispatcher.Dispatch(ctx, *s.destination) 50 if err != nil { 51 return nil, err 52 } 53 54 return net.NewConnection( 55 net.ConnectionInputMulti(link.Writer), 56 net.ConnectionOutputMulti(link.Reader), 57 ), nil 58 } 59 60 return s, nil 61 } 62 63 // NewTCPLocalNameServer creates DNS over TCP client object for local resolving 64 func NewTCPLocalNameServer(url *url.URL) (*TCPNameServer, error) { 65 s, err := baseTCPNameServer(url, "TCPL") 66 if err != nil { 67 return nil, err 68 } 69 70 s.dial = func(ctx context.Context) (net.Conn, error) { 71 return internet.DialSystem(ctx, *s.destination, nil) 72 } 73 74 return s, nil 75 } 76 77 func baseTCPNameServer(url *url.URL, prefix string) (*TCPNameServer, error) { 78 var err error 79 port := net.Port(53) 80 if url.Port() != "" { 81 port, err = net.PortFromString(url.Port()) 82 if err != nil { 83 return nil, err 84 } 85 } 86 dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port) 87 88 s := &TCPNameServer{ 89 destination: &dest, 90 ips: make(map[string]*record), 91 pub: pubsub.NewService(), 92 name: prefix + "//" + dest.NetAddr(), 93 } 94 s.cleanup = &task.Periodic{ 95 Interval: time.Minute, 96 Execute: s.Cleanup, 97 } 98 99 return s, nil 100 } 101 102 // Name implements Server. 103 func (s *TCPNameServer) Name() string { 104 return s.name 105 } 106 107 // Cleanup clears expired items from cache 108 func (s *TCPNameServer) Cleanup() error { 109 now := time.Now() 110 s.Lock() 111 defer s.Unlock() 112 113 if len(s.ips) == 0 { 114 return newError("nothing to do. stopping...") 115 } 116 117 for domain, record := range s.ips { 118 if record.A != nil && record.A.Expire.Before(now) { 119 record.A = nil 120 } 121 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 122 record.AAAA = nil 123 } 124 125 if record.A == nil && record.AAAA == nil { 126 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 127 delete(s.ips, domain) 128 } else { 129 s.ips[domain] = record 130 } 131 } 132 133 if len(s.ips) == 0 { 134 s.ips = make(map[string]*record) 135 } 136 137 return nil 138 } 139 140 func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 141 elapsed := time.Since(req.start) 142 143 s.Lock() 144 rec, found := s.ips[req.domain] 145 if !found { 146 rec = &record{} 147 } 148 updated := false 149 150 switch req.reqType { 151 case dnsmessage.TypeA: 152 if isNewer(rec.A, ipRec) { 153 rec.A = ipRec 154 updated = true 155 } 156 case dnsmessage.TypeAAAA: 157 addr := make([]net.Address, 0) 158 for _, ip := range ipRec.IP { 159 if len(ip.IP()) == net.IPv6len { 160 addr = append(addr, ip) 161 } 162 } 163 ipRec.IP = addr 164 if isNewer(rec.AAAA, ipRec) { 165 rec.AAAA = ipRec 166 updated = true 167 } 168 } 169 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 170 171 if updated { 172 s.ips[req.domain] = rec 173 } 174 switch req.reqType { 175 case dnsmessage.TypeA: 176 s.pub.Publish(req.domain+"4", nil) 177 case dnsmessage.TypeAAAA: 178 s.pub.Publish(req.domain+"6", nil) 179 } 180 s.Unlock() 181 common.Must(s.cleanup.Start()) 182 } 183 184 func (s *TCPNameServer) newReqID() uint16 { 185 return uint16(atomic.AddUint32(&s.reqID, 1)) 186 } 187 188 func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 189 newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 190 191 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 192 193 var deadline time.Time 194 if d, ok := ctx.Deadline(); ok { 195 deadline = d 196 } else { 197 deadline = time.Now().Add(time.Second * 5) 198 } 199 200 for _, req := range reqs { 201 go func(r *dnsRequest) { 202 dnsCtx := ctx 203 204 if inbound := session.InboundFromContext(ctx); inbound != nil { 205 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 206 } 207 208 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 209 Protocol: "dns", 210 SkipDNSResolve: true, 211 }) 212 213 var cancel context.CancelFunc 214 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 215 defer cancel() 216 217 b, err := dns.PackMessage(r.msg) 218 if err != nil { 219 newError("failed to pack dns query").Base(err).AtError().WriteToLog() 220 return 221 } 222 223 conn, err := s.dial(dnsCtx) 224 if err != nil { 225 newError("failed to dial namesever").Base(err).AtError().WriteToLog() 226 return 227 } 228 defer conn.Close() 229 dnsReqBuf := buf.New() 230 binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len())) 231 dnsReqBuf.Write(b.Bytes()) 232 b.Release() 233 234 _, err = conn.Write(dnsReqBuf.Bytes()) 235 if err != nil { 236 newError("failed to send query").Base(err).AtError().WriteToLog() 237 return 238 } 239 dnsReqBuf.Release() 240 241 respBuf := buf.New() 242 defer respBuf.Release() 243 n, err := respBuf.ReadFullFrom(conn, 2) 244 if err != nil && n == 0 { 245 newError("failed to read response length").Base(err).AtError().WriteToLog() 246 return 247 } 248 var length int16 249 err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length) 250 if err != nil { 251 newError("failed to parse response length").Base(err).AtError().WriteToLog() 252 return 253 } 254 respBuf.Clear() 255 n, err = respBuf.ReadFullFrom(conn, int32(length)) 256 if err != nil && n == 0 { 257 newError("failed to read response length").Base(err).AtError().WriteToLog() 258 return 259 } 260 261 rec, err := parseResponse(respBuf.Bytes()) 262 if err != nil { 263 newError("failed to parse DNS over TCP response").Base(err).AtError().WriteToLog() 264 return 265 } 266 267 s.updateIP(r, rec) 268 }(req) 269 } 270 } 271 272 func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 273 s.RLock() 274 record, found := s.ips[domain] 275 s.RUnlock() 276 277 if !found { 278 return nil, errRecordNotFound 279 } 280 281 var err4 error 282 var err6 error 283 var ips []net.Address 284 var ip6 []net.Address 285 286 if option.IPv4Enable { 287 ips, err4 = record.A.getIPs() 288 } 289 290 if option.IPv6Enable { 291 ip6, err6 = record.AAAA.getIPs() 292 ips = append(ips, ip6...) 293 } 294 295 if len(ips) > 0 { 296 return toNetIP(ips) 297 } 298 299 if err4 != nil { 300 return nil, err4 301 } 302 303 if err6 != nil { 304 return nil, err6 305 } 306 307 return nil, dns_feature.ErrEmptyResponse 308 } 309 310 // QueryIP implements Server. 311 func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { 312 fqdn := Fqdn(domain) 313 314 if disableCache { 315 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 316 } else { 317 ips, err := s.findIPsForDomain(fqdn, option) 318 if err != errRecordNotFound { 319 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 320 return ips, err 321 } 322 } 323 324 // ipv4 and ipv6 belong to different subscription groups 325 var sub4, sub6 *pubsub.Subscriber 326 if option.IPv4Enable { 327 sub4 = s.pub.Subscribe(fqdn + "4") 328 defer sub4.Close() 329 } 330 if option.IPv6Enable { 331 sub6 = s.pub.Subscribe(fqdn + "6") 332 defer sub6.Close() 333 } 334 done := make(chan interface{}) 335 go func() { 336 if sub4 != nil { 337 select { 338 case <-sub4.Wait(): 339 case <-ctx.Done(): 340 } 341 } 342 if sub6 != nil { 343 select { 344 case <-sub6.Wait(): 345 case <-ctx.Done(): 346 } 347 } 348 close(done) 349 }() 350 s.sendQuery(ctx, fqdn, clientIP, option) 351 352 for { 353 ips, err := s.findIPsForDomain(fqdn, option) 354 if err != errRecordNotFound { 355 return ips, err 356 } 357 358 select { 359 case <-ctx.Done(): 360 return nil, ctx.Err() 361 case <-done: 362 } 363 } 364 }