github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/dns/dns.go (about) 1 package dns 2 3 import ( 4 "context" 5 "io" 6 "sync" 7 "time" 8 9 "github.com/xtls/xray-core/common" 10 "github.com/xtls/xray-core/common/buf" 11 "github.com/xtls/xray-core/common/errors" 12 "github.com/xtls/xray-core/common/net" 13 dns_proto "github.com/xtls/xray-core/common/protocol/dns" 14 "github.com/xtls/xray-core/common/session" 15 "github.com/xtls/xray-core/common/signal" 16 "github.com/xtls/xray-core/common/task" 17 "github.com/xtls/xray-core/core" 18 "github.com/xtls/xray-core/features/dns" 19 "github.com/xtls/xray-core/features/policy" 20 "github.com/xtls/xray-core/transport" 21 "github.com/xtls/xray-core/transport/internet" 22 "github.com/xtls/xray-core/transport/internet/stat" 23 "golang.org/x/net/dns/dnsmessage" 24 ) 25 26 func init() { 27 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 28 h := new(Handler) 29 if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { 30 core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { 31 h.fdns = fdns 32 }) 33 return h.Init(config.(*Config), dnsClient, policyManager) 34 }); err != nil { 35 return nil, err 36 } 37 return h, nil 38 })) 39 } 40 41 type ownLinkVerifier interface { 42 IsOwnLink(ctx context.Context) bool 43 } 44 45 type Handler struct { 46 client dns.Client 47 fdns dns.FakeDNSEngine 48 ownLinkVerifier ownLinkVerifier 49 server net.Destination 50 timeout time.Duration 51 nonIPQuery string 52 } 53 54 func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error { 55 h.client = dnsClient 56 h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle 57 58 if v, ok := dnsClient.(ownLinkVerifier); ok { 59 h.ownLinkVerifier = v 60 } 61 62 if config.Server != nil { 63 h.server = config.Server.AsDestination() 64 } 65 h.nonIPQuery = config.Non_IPQuery 66 return nil 67 } 68 69 func (h *Handler) isOwnLink(ctx context.Context) bool { 70 return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) 71 } 72 73 func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { 74 var parser dnsmessage.Parser 75 header, err := parser.Start(b) 76 if err != nil { 77 newError("parser start").Base(err).WriteToLog() 78 return 79 } 80 81 id = header.ID 82 q, err := parser.Question() 83 if err != nil { 84 newError("question").Base(err).WriteToLog() 85 return 86 } 87 qType = q.Type 88 if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { 89 return 90 } 91 92 domain = q.Name.String() 93 r = true 94 return 95 } 96 97 // Process implements proxy.Outbound. 98 func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { 99 outbounds := session.OutboundsFromContext(ctx) 100 ob := outbounds[len(outbounds) - 1] 101 if !ob.Target.IsValid() { 102 return newError("invalid outbound") 103 } 104 ob.Name = "dns" 105 106 srcNetwork := ob.Target.Network 107 108 dest := ob.Target 109 if h.server.Network != net.Network_Unknown { 110 dest.Network = h.server.Network 111 } 112 if h.server.Address != nil { 113 dest.Address = h.server.Address 114 } 115 if h.server.Port != 0 { 116 dest.Port = h.server.Port 117 } 118 119 newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx)) 120 121 conn := &outboundConn{ 122 dialer: func() (stat.Connection, error) { 123 return d.Dial(ctx, dest) 124 }, 125 connReady: make(chan struct{}, 1), 126 } 127 128 var reader dns_proto.MessageReader 129 var writer dns_proto.MessageWriter 130 if srcNetwork == net.Network_TCP { 131 reader = dns_proto.NewTCPReader(link.Reader) 132 writer = &dns_proto.TCPWriter{ 133 Writer: link.Writer, 134 } 135 } else { 136 reader = &dns_proto.UDPReader{ 137 Reader: link.Reader, 138 } 139 writer = &dns_proto.UDPWriter{ 140 Writer: link.Writer, 141 } 142 } 143 144 var connReader dns_proto.MessageReader 145 var connWriter dns_proto.MessageWriter 146 if dest.Network == net.Network_TCP { 147 connReader = dns_proto.NewTCPReader(buf.NewReader(conn)) 148 connWriter = &dns_proto.TCPWriter{ 149 Writer: buf.NewWriter(conn), 150 } 151 } else { 152 connReader = &dns_proto.UDPReader{ 153 Reader: buf.NewPacketReader(conn), 154 } 155 connWriter = &dns_proto.UDPWriter{ 156 Writer: buf.NewWriter(conn), 157 } 158 } 159 160 if session.TimeoutOnlyFromContext(ctx) { 161 ctx, _ = context.WithCancel(context.Background()) 162 } 163 164 ctx, cancel := context.WithCancel(ctx) 165 timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) 166 167 request := func() error { 168 defer conn.Close() 169 170 for { 171 b, err := reader.ReadMessage() 172 if err == io.EOF { 173 return nil 174 } 175 176 if err != nil { 177 return err 178 } 179 180 timer.Update() 181 182 if !h.isOwnLink(ctx) { 183 isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) 184 if isIPQuery { 185 go h.handleIPQuery(id, qType, domain, writer) 186 } 187 if isIPQuery || h.nonIPQuery == "drop" || qType == 65 { 188 b.Release() 189 continue 190 } 191 } 192 193 if err := connWriter.WriteMessage(b); err != nil { 194 return err 195 } 196 } 197 } 198 199 response := func() error { 200 for { 201 b, err := connReader.ReadMessage() 202 if err == io.EOF { 203 return nil 204 } 205 206 if err != nil { 207 return err 208 } 209 210 timer.Update() 211 212 if err := writer.WriteMessage(b); err != nil { 213 return err 214 } 215 } 216 } 217 218 if err := task.Run(ctx, request, response); err != nil { 219 return newError("connection ends").Base(err) 220 } 221 222 return nil 223 } 224 225 func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { 226 var ips []net.IP 227 var err error 228 229 var ttl uint32 = 600 230 231 switch qType { 232 case dnsmessage.TypeA: 233 ips, err = h.client.LookupIP(domain, dns.IPOption{ 234 IPv4Enable: true, 235 IPv6Enable: false, 236 FakeEnable: true, 237 }) 238 case dnsmessage.TypeAAAA: 239 ips, err = h.client.LookupIP(domain, dns.IPOption{ 240 IPv4Enable: false, 241 IPv6Enable: true, 242 FakeEnable: true, 243 }) 244 } 245 246 rcode := dns.RCodeFromError(err) 247 if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) { 248 newError("ip query").Base(err).WriteToLog() 249 return 250 } 251 252 if fkr0, ok := h.fdns.(dns.FakeDNSEngineRev0); ok && len(ips) > 0 && fkr0.IsIPInIPPool(net.IPAddress(ips[0])) { 253 ttl = 1 254 } 255 256 switch qType { 257 case dnsmessage.TypeA: 258 for i, ip := range ips { 259 ips[i] = ip.To4() 260 } 261 case dnsmessage.TypeAAAA: 262 for i, ip := range ips { 263 ips[i] = ip.To16() 264 } 265 } 266 267 b := buf.New() 268 rawBytes := b.Extend(buf.Size) 269 builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ 270 ID: id, 271 RCode: dnsmessage.RCode(rcode), 272 RecursionAvailable: true, 273 RecursionDesired: true, 274 Response: true, 275 Authoritative: true, 276 }) 277 builder.EnableCompression() 278 common.Must(builder.StartQuestions()) 279 common.Must(builder.Question(dnsmessage.Question{ 280 Name: dnsmessage.MustNewName(domain), 281 Class: dnsmessage.ClassINET, 282 Type: qType, 283 })) 284 common.Must(builder.StartAnswers()) 285 286 rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl} 287 for _, ip := range ips { 288 if len(ip) == net.IPv4len { 289 var r dnsmessage.AResource 290 copy(r.A[:], ip) 291 common.Must(builder.AResource(rHeader, r)) 292 } else { 293 var r dnsmessage.AAAAResource 294 copy(r.AAAA[:], ip) 295 common.Must(builder.AAAAResource(rHeader, r)) 296 } 297 } 298 msgBytes, err := builder.Finish() 299 if err != nil { 300 newError("pack message").Base(err).WriteToLog() 301 b.Release() 302 return 303 } 304 b.Resize(0, int32(len(msgBytes))) 305 306 if err := writer.WriteMessage(b); err != nil { 307 newError("write IP answer").Base(err).WriteToLog() 308 } 309 } 310 311 type outboundConn struct { 312 access sync.Mutex 313 dialer func() (stat.Connection, error) 314 315 conn net.Conn 316 connReady chan struct{} 317 } 318 319 func (c *outboundConn) dial() error { 320 conn, err := c.dialer() 321 if err != nil { 322 return err 323 } 324 c.conn = conn 325 c.connReady <- struct{}{} 326 return nil 327 } 328 329 func (c *outboundConn) Write(b []byte) (int, error) { 330 c.access.Lock() 331 332 if c.conn == nil { 333 if err := c.dial(); err != nil { 334 c.access.Unlock() 335 newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog() 336 return len(b), nil 337 } 338 } 339 340 c.access.Unlock() 341 342 return c.conn.Write(b) 343 } 344 345 func (c *outboundConn) Read(b []byte) (int, error) { 346 var conn net.Conn 347 c.access.Lock() 348 conn = c.conn 349 c.access.Unlock() 350 351 if conn == nil { 352 _, open := <-c.connReady 353 if !open { 354 return 0, io.EOF 355 } 356 conn = c.conn 357 } 358 359 return conn.Read(b) 360 } 361 362 func (c *outboundConn) Close() error { 363 c.access.Lock() 364 close(c.connReady) 365 if c.conn != nil { 366 c.conn.Close() 367 } 368 c.access.Unlock() 369 return nil 370 }