github.com/xmplusdev/xmcore@v1.8.11-0.20240412132628-5518b55526af/app/dns/nameserver_tcp.go (about) 1 package dns 2 3 import ( 4 "bytes" 5 "context" 6 "encoding/binary" 7 "net/url" 8 "sync" 9 "sync/atomic" 10 "time" 11 12 "github.com/xmplusdev/xmcore/common" 13 "github.com/xmplusdev/xmcore/common/buf" 14 "github.com/xmplusdev/xmcore/common/log" 15 "github.com/xmplusdev/xmcore/common/net" 16 "github.com/xmplusdev/xmcore/common/net/cnc" 17 "github.com/xmplusdev/xmcore/common/protocol/dns" 18 "github.com/xmplusdev/xmcore/common/session" 19 "github.com/xmplusdev/xmcore/common/signal/pubsub" 20 "github.com/xmplusdev/xmcore/common/task" 21 dns_feature "github.com/xmplusdev/xmcore/features/dns" 22 "github.com/xmplusdev/xmcore/features/routing" 23 "github.com/xmplusdev/xmcore/transport/internet" 24 "golang.org/x/net/dns/dnsmessage" 25 ) 26 27 // TCPNameServer implemented DNS over TCP (RFC7766). 28 type TCPNameServer struct { 29 sync.RWMutex 30 name string 31 destination *net.Destination 32 ips map[string]*record 33 pub *pubsub.Service 34 cleanup *task.Periodic 35 reqID uint32 36 dial func(context.Context) (net.Conn, error) 37 queryStrategy QueryStrategy 38 } 39 40 // NewTCPNameServer creates DNS over TCP server object for remote resolving. 41 func NewTCPNameServer( 42 url *url.URL, 43 dispatcher routing.Dispatcher, 44 queryStrategy QueryStrategy, 45 ) (*TCPNameServer, error) { 46 s, err := baseTCPNameServer(url, "TCP", queryStrategy) 47 if err != nil { 48 return nil, err 49 } 50 51 s.dial = func(ctx context.Context) (net.Conn, error) { 52 link, err := dispatcher.Dispatch(toDnsContext(ctx, s.destination.String()), *s.destination) 53 if err != nil { 54 return nil, err 55 } 56 57 return cnc.NewConnection( 58 cnc.ConnectionInputMulti(link.Writer), 59 cnc.ConnectionOutputMulti(link.Reader), 60 ), nil 61 } 62 63 return s, nil 64 } 65 66 // NewTCPLocalNameServer creates DNS over TCP client object for local resolving 67 func NewTCPLocalNameServer(url *url.URL, queryStrategy QueryStrategy) (*TCPNameServer, error) { 68 s, err := baseTCPNameServer(url, "TCPL", queryStrategy) 69 if err != nil { 70 return nil, err 71 } 72 73 s.dial = func(ctx context.Context) (net.Conn, error) { 74 return internet.DialSystem(ctx, *s.destination, nil) 75 } 76 77 return s, nil 78 } 79 80 func baseTCPNameServer(url *url.URL, prefix string, queryStrategy QueryStrategy) (*TCPNameServer, error) { 81 port := net.Port(53) 82 if url.Port() != "" { 83 var err error 84 if port, err = net.PortFromString(url.Port()); err != nil { 85 return nil, err 86 } 87 } 88 dest := net.TCPDestination(net.ParseAddress(url.Hostname()), port) 89 90 s := &TCPNameServer{ 91 destination: &dest, 92 ips: make(map[string]*record), 93 pub: pubsub.NewService(), 94 name: prefix + "//" + dest.NetAddr(), 95 queryStrategy: queryStrategy, 96 } 97 s.cleanup = &task.Periodic{ 98 Interval: time.Minute, 99 Execute: s.Cleanup, 100 } 101 102 return s, nil 103 } 104 105 // Name implements Server. 106 func (s *TCPNameServer) Name() string { 107 return s.name 108 } 109 110 // Cleanup clears expired items from cache 111 func (s *TCPNameServer) Cleanup() error { 112 now := time.Now() 113 s.Lock() 114 defer s.Unlock() 115 116 if len(s.ips) == 0 { 117 return newError("nothing to do. stopping...") 118 } 119 120 for domain, record := range s.ips { 121 if record.A != nil && record.A.Expire.Before(now) { 122 record.A = nil 123 } 124 if record.AAAA != nil && record.AAAA.Expire.Before(now) { 125 record.AAAA = nil 126 } 127 128 if record.A == nil && record.AAAA == nil { 129 newError(s.name, " cleanup ", domain).AtDebug().WriteToLog() 130 delete(s.ips, domain) 131 } else { 132 s.ips[domain] = record 133 } 134 } 135 136 if len(s.ips) == 0 { 137 s.ips = make(map[string]*record) 138 } 139 140 return nil 141 } 142 143 func (s *TCPNameServer) updateIP(req *dnsRequest, ipRec *IPRecord) { 144 elapsed := time.Since(req.start) 145 146 s.Lock() 147 rec, found := s.ips[req.domain] 148 if !found { 149 rec = &record{} 150 } 151 updated := false 152 153 switch req.reqType { 154 case dnsmessage.TypeA: 155 if isNewer(rec.A, ipRec) { 156 rec.A = ipRec 157 updated = true 158 } 159 case dnsmessage.TypeAAAA: 160 addr := make([]net.Address, 0) 161 for _, ip := range ipRec.IP { 162 if len(ip.IP()) == net.IPv6len { 163 addr = append(addr, ip) 164 } 165 } 166 ipRec.IP = addr 167 if isNewer(rec.AAAA, ipRec) { 168 rec.AAAA = ipRec 169 updated = true 170 } 171 } 172 newError(s.name, " got answer: ", req.domain, " ", req.reqType, " -> ", ipRec.IP, " ", elapsed).AtInfo().WriteToLog() 173 174 if updated { 175 s.ips[req.domain] = rec 176 } 177 switch req.reqType { 178 case dnsmessage.TypeA: 179 s.pub.Publish(req.domain+"4", nil) 180 case dnsmessage.TypeAAAA: 181 s.pub.Publish(req.domain+"6", nil) 182 } 183 s.Unlock() 184 common.Must(s.cleanup.Start()) 185 } 186 187 func (s *TCPNameServer) newReqID() uint16 { 188 return uint16(atomic.AddUint32(&s.reqID, 1)) 189 } 190 191 func (s *TCPNameServer) sendQuery(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption) { 192 newError(s.name, " querying DNS for: ", domain).AtDebug().WriteToLog(session.ExportIDToError(ctx)) 193 194 reqs := buildReqMsgs(domain, option, s.newReqID, genEDNS0Options(clientIP)) 195 196 var deadline time.Time 197 if d, ok := ctx.Deadline(); ok { 198 deadline = d 199 } else { 200 deadline = time.Now().Add(time.Second * 5) 201 } 202 203 for _, req := range reqs { 204 go func(r *dnsRequest) { 205 dnsCtx := ctx 206 207 if inbound := session.InboundFromContext(ctx); inbound != nil { 208 dnsCtx = session.ContextWithInbound(dnsCtx, inbound) 209 } 210 211 dnsCtx = session.ContextWithContent(dnsCtx, &session.Content{ 212 Protocol: "dns", 213 SkipDNSResolve: true, 214 }) 215 216 var cancel context.CancelFunc 217 dnsCtx, cancel = context.WithDeadline(dnsCtx, deadline) 218 defer cancel() 219 220 b, err := dns.PackMessage(r.msg) 221 if err != nil { 222 newError("failed to pack dns query").Base(err).AtError().WriteToLog() 223 return 224 } 225 226 conn, err := s.dial(dnsCtx) 227 if err != nil { 228 newError("failed to dial namesever").Base(err).AtError().WriteToLog() 229 return 230 } 231 defer conn.Close() 232 dnsReqBuf := buf.New() 233 binary.Write(dnsReqBuf, binary.BigEndian, uint16(b.Len())) 234 dnsReqBuf.Write(b.Bytes()) 235 b.Release() 236 237 _, err = conn.Write(dnsReqBuf.Bytes()) 238 if err != nil { 239 newError("failed to send query").Base(err).AtError().WriteToLog() 240 return 241 } 242 dnsReqBuf.Release() 243 244 respBuf := buf.New() 245 defer respBuf.Release() 246 n, err := respBuf.ReadFullFrom(conn, 2) 247 if err != nil && n == 0 { 248 newError("failed to read response length").Base(err).AtError().WriteToLog() 249 return 250 } 251 var length int16 252 err = binary.Read(bytes.NewReader(respBuf.Bytes()), binary.BigEndian, &length) 253 if err != nil { 254 newError("failed to parse response length").Base(err).AtError().WriteToLog() 255 return 256 } 257 respBuf.Clear() 258 n, err = respBuf.ReadFullFrom(conn, int32(length)) 259 if err != nil && n == 0 { 260 newError("failed to read response length").Base(err).AtError().WriteToLog() 261 return 262 } 263 264 rec, err := parseResponse(respBuf.Bytes()) 265 if err != nil { 266 newError("failed to parse DNS over TCP response").Base(err).AtError().WriteToLog() 267 return 268 } 269 270 s.updateIP(r, rec) 271 }(req) 272 } 273 } 274 275 func (s *TCPNameServer) findIPsForDomain(domain string, option dns_feature.IPOption) ([]net.IP, error) { 276 s.RLock() 277 record, found := s.ips[domain] 278 s.RUnlock() 279 280 if !found { 281 return nil, errRecordNotFound 282 } 283 284 var err4 error 285 var err6 error 286 var ips []net.Address 287 var ip6 []net.Address 288 289 if option.IPv4Enable { 290 ips, err4 = record.A.getIPs() 291 } 292 293 if option.IPv6Enable { 294 ip6, err6 = record.AAAA.getIPs() 295 ips = append(ips, ip6...) 296 } 297 298 if len(ips) > 0 { 299 return toNetIP(ips) 300 } 301 302 if err4 != nil { 303 return nil, err4 304 } 305 306 if err6 != nil { 307 return nil, err6 308 } 309 310 return nil, dns_feature.ErrEmptyResponse 311 } 312 313 // QueryIP implements Server. 314 func (s *TCPNameServer) QueryIP(ctx context.Context, domain string, clientIP net.IP, option dns_feature.IPOption, disableCache bool) ([]net.IP, error) { 315 fqdn := Fqdn(domain) 316 option = ResolveIpOptionOverride(s.queryStrategy, option) 317 if !option.IPv4Enable && !option.IPv6Enable { 318 return nil, dns_feature.ErrEmptyResponse 319 } 320 321 if disableCache { 322 newError("DNS cache is disabled. Querying IP for ", domain, " at ", s.name).AtDebug().WriteToLog() 323 } else { 324 ips, err := s.findIPsForDomain(fqdn, option) 325 if err != errRecordNotFound { 326 newError(s.name, " cache HIT ", domain, " -> ", ips).Base(err).AtDebug().WriteToLog() 327 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSCacheHit, Elapsed: 0, Error: err}) 328 return ips, err 329 } 330 } 331 332 // ipv4 and ipv6 belong to different subscription groups 333 var sub4, sub6 *pubsub.Subscriber 334 if option.IPv4Enable { 335 sub4 = s.pub.Subscribe(fqdn + "4") 336 defer sub4.Close() 337 } 338 if option.IPv6Enable { 339 sub6 = s.pub.Subscribe(fqdn + "6") 340 defer sub6.Close() 341 } 342 done := make(chan interface{}) 343 go func() { 344 if sub4 != nil { 345 select { 346 case <-sub4.Wait(): 347 case <-ctx.Done(): 348 } 349 } 350 if sub6 != nil { 351 select { 352 case <-sub6.Wait(): 353 case <-ctx.Done(): 354 } 355 } 356 close(done) 357 }() 358 s.sendQuery(ctx, fqdn, clientIP, option) 359 start := time.Now() 360 361 for { 362 ips, err := s.findIPsForDomain(fqdn, option) 363 if err != errRecordNotFound { 364 log.Record(&log.DNSLog{Server: s.name, Domain: domain, Result: ips, Status: log.DNSQueried, Elapsed: time.Since(start), Error: err}) 365 return ips, err 366 } 367 368 select { 369 case <-ctx.Done(): 370 return nil, ctx.Err() 371 case <-done: 372 } 373 } 374 }