github.com/imannamdari/v2ray-core/v5@v5.0.5/app/dns/nameserver_tcp.go (about) 1 //go:build !confonly 2 // +build !confonly 3 4 package dns 5 6 import ( 7 "bytes" 8 "context" 9 "encoding/binary" 10 "net/url" 11 "sync" 12 "sync/atomic" 13 "time" 14 15 "golang.org/x/net/dns/dnsmessage" 16 17 "github.com/imannamdari/v2ray-core/v5/common" 18 "github.com/imannamdari/v2ray-core/v5/common/buf" 19 "github.com/imannamdari/v2ray-core/v5/common/net" 20 "github.com/imannamdari/v2ray-core/v5/common/protocol/dns" 21 "github.com/imannamdari/v2ray-core/v5/common/session" 22 "github.com/imannamdari/v2ray-core/v5/common/signal/pubsub" 23 "github.com/imannamdari/v2ray-core/v5/common/task" 24 dns_feature "github.com/imannamdari/v2ray-core/v5/features/dns" 25 "github.com/imannamdari/v2ray-core/v5/features/routing" 26 "github.com/imannamdari/v2ray-core/v5/transport/internet" 27 ) 28 29 // TCPNameServer implemented DNS over TCP (RFC7766). 30 type TCPNameServer struct { 31 sync.RWMutex 32 name string 33 destination net.Destination 34 ips map[string]record 35 pub *pubsub.Service 36 cleanup *task.Periodic 37 reqID uint32 38 dial func(context.Context) (net.Conn, error) 39 } 40 41 // NewTCPNameServer creates DNS over TCP server object for remote resolving. 42 func NewTCPNameServer(url *url.URL, dispatcher routing.Dispatcher) (*TCPNameServer, error) { 43 s, err := baseTCPNameServer(url, "TCP") 44 if err != nil { 45 return nil, err 46 } 47 48 s.dial = func(ctx context.Context) (net.Conn, error) { 49 link, err := dispatcher.Dispatch(ctx, s.destination) 50 if err != nil { 51 return nil, err 52 } 53 54 return net.NewConnection( 55 net.ConnectionInputMulti(link.Writer), 56 net.ConnectionOutputMulti(link.Reader), 57 ), nil 58 } 59 60 return s, nil 61 } 62 63 // NewTCPLocalNameServer creates DNS over TCP client object for local resolving 64 func NewTCPLocalNameServer(url *url.URL) (*TCPNameServer, error) { 65 s, err := baseTCPNameServer(url, "TCPL") 66 if err != nil { 67 return nil, err 68 } 69 70 s.dial = func(ctx context.Context) (net.Conn, error) { 71 return internet.DialSystem(ctx, s.destination, nil) 72 } 73 74 return s, nil 75 } 76 77 func baseTCPNameServer(url *url.URL, prefix string) (*TCPNameServer, error) { 78 var err error 79 port := net.Port(53) 80 if url.Port() != "" { 81 port, err = net.PortFromString(url.Port()) 82 if err != nil { 83 return nil, err 84 } 85 } 86 dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port) 87 88 s := &TCPNameServer{ 89 destination: dest, 90 ips: make(map[string]record), 91 pub: pubsub.NewService(), 92 name: prefix + "//" + dest.NetAddr(), 93 } 94 s.cleanup = &task.Periodic{ 95 Interval: time.Minute, 96 Execute: s.Cleanup, 97 } 98 99 return s, nil 100 } 101 102 // Name implements Server. 103 func (s *TCPNameServer) Name() string { 104 return s.name 105 } 106 107 // Cleanup clears expired items from cache 108 func (s *TCPNameServer) Cleanup() error { 109 now := time.Now() 110 s.Lock() 111 defer s.Unlock() 112 113 if len(s.ips) == 0 { 114 return newError("nothing to do. stopping...") 115 } 116 117 for domain, record := range s.ips { 118 if record.A != nil && record.A.Expire.Before(now) { 119 record.A = nil 120 } 121 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 122 record.AAAA = nil 123 } 124 125 if record.A == nil && record.AAAA == nil { 126 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 127 delete(s.ips, domain) 128 } else { 129 s.ips[domain] = record 130 } 131 } 132 133 if len(s.ips) == 0 { 134 s.ips = make(map[string]record) 135 } 136 137 return nil 138 } 139 140 func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 141 elapsed := time.Since(req.start) 142 143 s.Lock() 144 rec := s.ips[req.domain] 145 updated := false 146 147 switch req.reqType { 148 case dnsmessage.TypeA: 149 if isNewer(rec.A, ipRec) { 150 rec.A = ipRec 151 updated = true 152 } 153 case dnsmessage.TypeAAAA: 154 addr := make([]net.Address, 0) 155 for _, ip := range ipRec.IP { 156 if len(ip.IP()) == net.IPv6len { 157 addr = append(addr, ip) 158 } 159 } 160 ipRec.IP = addr 161 if isNewer(rec.AAAA, ipRec) { 162 rec.AAAA = ipRec 163 updated = true 164 } 165 } 166 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 167 168 if updated { 169 s.ips[req.domain] = rec 170 } 171 switch req.reqType { 172 case dnsmessage.TypeA: 173 s.pub.Publish(req.domain+"4", nil) 174 case dnsmessage.TypeAAAA: 175 s.pub.Publish(req.domain+"6", nil) 176 } 177 s.Unlock() 178 common.Must(s.cleanup.Start()) 179 } 180 181 func (s *TCPNameServer) newReqID() uint16 { 182 return uint16(atomic.AddUint32(&s.reqID, 1)) 183 } 184 185 func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 186 newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 187 188 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 189 190 var deadline time.Time 191 if d, ok := ctx.Deadline(); ok { 192 deadline = d 193 } else { 194 deadline = time.Now().Add(time.Second * 5) 195 } 196 197 for _, req := range reqs { 198 go func(r *dnsRequest) { 199 dnsCtx := ctx 200 201 if inbound := session.InboundFromContext(ctx); inbound != nil { 202 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 203 } 204 205 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 206 Protocol: "dns", 207 SkipDNSResolve: true, 208 }) 209 210 var cancel context.CancelFunc 211 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 212 defer cancel() 213 214 b, err := dns.PackMessage(r.msg) 215 if err != nil { 216 newError("failed to pack dns query").Base(err).AtError().WriteToLog() 217 return 218 } 219 220 conn, err := s.dial(dnsCtx) 221 if err != nil { 222 newError("failed to dial namesever").Base(err).AtError().WriteToLog() 223 return 224 } 225 defer conn.Close() 226 dnsReqBuf := buf.New() 227 binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len())) 228 dnsReqBuf.Write(b.Bytes()) 229 b.Release() 230 231 _, err = conn.Write(dnsReqBuf.Bytes()) 232 if err != nil { 233 newError("failed to send query").Base(err).AtError().WriteToLog() 234 return 235 } 236 dnsReqBuf.Release() 237 238 respBuf := buf.New() 239 defer respBuf.Release() 240 n, err := respBuf.ReadFullFrom(conn, 2) 241 if err != nil && n == 0 { 242 newError("failed to read response length").Base(err).AtError().WriteToLog() 243 return 244 } 245 var length int16 246 err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length) 247 if err != nil { 248 newError("failed to parse response length").Base(err).AtError().WriteToLog() 249 return 250 } 251 respBuf.Clear() 252 n, err = respBuf.ReadFullFrom(conn, int32(length)) 253 if err != nil && n == 0 { 254 newError("failed to read response length").Base(err).AtError().WriteToLog() 255 return 256 } 257 258 rec, err := parseResponse(respBuf.Bytes()) 259 if err != nil { 260 newError("failed to parse DNS over TCP response").Base(err).AtError().WriteToLog() 261 return 262 } 263 264 s.updateIP(r, rec) 265 }(req) 266 } 267 } 268 269 func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 270 s.RLock() 271 record, found := s.ips[domain] 272 s.RUnlock() 273 274 if !found { 275 return nil, errRecordNotFound 276 } 277 278 var ips []net.Address 279 var lastErr error 280 if option.IPv4Enable { 281 a, err := record.A.getIPs() 282 if err != nil { 283 lastErr = err 284 } 285 ips = append(ips, a...) 286 } 287 288 if option.IPv6Enable { 289 aaaa, err := record.AAAA.getIPs() 290 if err != nil { 291 lastErr = err 292 } 293 ips = append(ips, aaaa...) 294 } 295 296 if len(ips) > 0 { 297 return toNetIP(ips) 298 } 299 300 if lastErr != nil { 301 return nil, lastErr 302 } 303 304 return nil, dns_feature.ErrEmptyResponse 305 } 306 307 // QueryIP implements Server. 308 func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { 309 fqdn := Fqdn(domain) 310 311 if disableCache { 312 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 313 } else { 314 ips, err := s.findIPsForDomain(fqdn, option) 315 if err != errRecordNotFound { 316 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 317 return ips, err 318 } 319 } 320 321 // ipv4 and ipv6 belong to different subscription groups 322 var sub4, sub6 *pubsub.Subscriber 323 if option.IPv4Enable { 324 sub4 = s.pub.Subscribe(fqdn + "4") 325 defer sub4.Close() 326 } 327 if option.IPv6Enable { 328 sub6 = s.pub.Subscribe(fqdn + "6") 329 defer sub6.Close() 330 } 331 done := make(chan interface{}) 332 go func() { 333 if sub4 != nil { 334 select { 335 case <-sub4.Wait(): 336 case <-ctx.Done(): 337 } 338 } 339 if sub6 != nil { 340 select { 341 case <-sub6.Wait(): 342 case <-ctx.Done(): 343 } 344 } 345 close(done) 346 }() 347 s.sendQuery(ctx, fqdn, clientIP, option) 348 349 for { 350 ips, err := s.findIPsForDomain(fqdn, option) 351 if err != errRecordNotFound { 352 return ips, err 353 } 354 355 select { 356 case <-ctx.Done(): 357 return nil, ctx.Err() 358 case <-done: 359 } 360 } 361 }