github.com/xmplusdev/xray-core@v1.8.10/proxy/dns/dns.go (about) 1 package dns 2 3 import ( 4 "context" 5 "io" 6 "sync" 7 "time" 8 9 "github.com/xmplusdev/xray-core/common" 10 "github.com/xmplusdev/xray-core/common/buf" 11 "github.com/xmplusdev/xray-core/common/errors" 12 "github.com/xmplusdev/xray-core/common/net" 13 dns_proto "github.com/xmplusdev/xray-core/common/protocol/dns" 14 "github.com/xmplusdev/xray-core/common/session" 15 "github.com/xmplusdev/xray-core/common/signal" 16 "github.com/xmplusdev/xray-core/common/task" 17 "github.com/xmplusdev/xray-core/core" 18 "github.com/xmplusdev/xray-core/features/dns" 19 "github.com/xmplusdev/xray-core/features/policy" 20 "github.com/xmplusdev/xray-core/transport" 21 "github.com/xmplusdev/xray-core/transport/internet" 22 "github.com/xmplusdev/xray-core/transport/internet/stat" 23 "golang.org/x/net/dns/dnsmessage" 24 ) 25 26 func init() { 27 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 28 h := new(Handler) 29 if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { 30 core.RequireFeatures(ctx, func(fdns dns.FakeDNSEngine) { 31 h.fdns = fdns 32 }) 33 return h.Init(config.(*Config), dnsClient, policyManager) 34 }); err != nil { 35 return nil, err 36 } 37 return h, nil 38 })) 39 } 40 41 type ownLinkVerifier interface { 42 IsOwnLink(ctx context.Context) bool 43 } 44 45 type Handler struct { 46 client dns.Client 47 fdns dns.FakeDNSEngine 48 ownLinkVerifier ownLinkVerifier 49 server net.Destination 50 timeout time.Duration 51 nonIPQuery string 52 } 53 54 func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error { 55 h.client = dnsClient 56 h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle 57 58 if v, ok := dnsClient.(ownLinkVerifier); ok { 59 h.ownLinkVerifier = v 60 } 61 62 if config.Server != nil { 63 h.server = config.Server.AsDestination() 64 } 65 h.nonIPQuery = config.Non_IPQuery 66 return nil 67 } 68 69 func (h *Handler) isOwnLink(ctx context.Context) bool { 70 return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) 71 } 72 73 func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { 74 var parser dnsmessage.Parser 75 header, err := parser.Start(b) 76 if err != nil { 77 newError("parser start").Base(err).WriteToLog() 78 return 79 } 80 81 id = header.ID 82 q, err := parser.Question() 83 if err != nil { 84 newError("question").Base(err).WriteToLog() 85 return 86 } 87 qType = q.Type 88 if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { 89 return 90 } 91 92 domain = q.Name.String() 93 r = true 94 return 95 } 96 97 // Process implements proxy.Outbound. 98 func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { 99 outbound := session.OutboundFromContext(ctx) 100 if outbound == nil || !outbound.Target.IsValid() { 101 return newError("invalid outbound") 102 } 103 outbound.Name = "dns" 104 105 srcNetwork := outbound.Target.Network 106 107 dest := outbound.Target 108 if h.server.Network != net.Network_Unknown { 109 dest.Network = h.server.Network 110 } 111 if h.server.Address != nil { 112 dest.Address = h.server.Address 113 } 114 if h.server.Port != 0 { 115 dest.Port = h.server.Port 116 } 117 118 newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx)) 119 120 conn := &outboundConn{ 121 dialer: func() (stat.Connection, error) { 122 return d.Dial(ctx, dest) 123 }, 124 connReady: make(chan struct{}, 1), 125 } 126 127 var reader dns_proto.MessageReader 128 var writer dns_proto.MessageWriter 129 if srcNetwork == net.Network_TCP { 130 reader = dns_proto.NewTCPReader(link.Reader) 131 writer = &dns_proto.TCPWriter{ 132 Writer: link.Writer, 133 } 134 } else { 135 reader = &dns_proto.UDPReader{ 136 Reader: link.Reader, 137 } 138 writer = &dns_proto.UDPWriter{ 139 Writer: link.Writer, 140 } 141 } 142 143 var connReader dns_proto.MessageReader 144 var connWriter dns_proto.MessageWriter 145 if dest.Network == net.Network_TCP { 146 connReader = dns_proto.NewTCPReader(buf.NewReader(conn)) 147 connWriter = &dns_proto.TCPWriter{ 148 Writer: buf.NewWriter(conn), 149 } 150 } else { 151 connReader = &dns_proto.UDPReader{ 152 Reader: buf.NewPacketReader(conn), 153 } 154 connWriter = &dns_proto.UDPWriter{ 155 Writer: buf.NewWriter(conn), 156 } 157 } 158 159 if session.TimeoutOnlyFromContext(ctx) { 160 ctx, _ = context.WithCancel(context.Background()) 161 } 162 163 ctx, cancel := context.WithCancel(ctx) 164 timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) 165 166 request := func() error { 167 defer conn.Close() 168 169 for { 170 b, err := reader.ReadMessage() 171 if err == io.EOF { 172 return nil 173 } 174 175 if err != nil { 176 return err 177 } 178 179 timer.Update() 180 181 if !h.isOwnLink(ctx) { 182 isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) 183 if isIPQuery { 184 go h.handleIPQuery(id, qType, domain, writer) 185 } 186 if isIPQuery || h.nonIPQuery == "drop" || qType == 65 { 187 b.Release() 188 continue 189 } 190 } 191 192 if err := connWriter.WriteMessage(b); err != nil { 193 return err 194 } 195 } 196 } 197 198 response := func() error { 199 for { 200 b, err := connReader.ReadMessage() 201 if err == io.EOF { 202 return nil 203 } 204 205 if err != nil { 206 return err 207 } 208 209 timer.Update() 210 211 if err := writer.WriteMessage(b); err != nil { 212 return err 213 } 214 } 215 } 216 217 if err := task.Run(ctx, request, response); err != nil { 218 return newError("connection ends").Base(err) 219 } 220 221 return nil 222 } 223 224 func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { 225 var ips []net.IP 226 var err error 227 228 var ttl uint32 = 600 229 230 switch qType { 231 case dnsmessage.TypeA: 232 ips, err = h.client.LookupIP(domain, dns.IPOption{ 233 IPv4Enable: true, 234 IPv6Enable: false, 235 FakeEnable: true, 236 }) 237 case dnsmessage.TypeAAAA: 238 ips, err = h.client.LookupIP(domain, dns.IPOption{ 239 IPv4Enable: false, 240 IPv6Enable: true, 241 FakeEnable: true, 242 }) 243 } 244 245 rcode := dns.RCodeFromError(err) 246 if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) { 247 newError("ip query").Base(err).WriteToLog() 248 return 249 } 250 251 if fkr0, ok := h.fdns.(dns.FakeDNSEngineRev0); ok && len(ips) > 0 && fkr0.IsIPInIPPool(net.IPAddress(ips[0])) { 252 ttl = 1 253 } 254 255 switch qType { 256 case dnsmessage.TypeA: 257 for i, ip := range ips { 258 ips[i] = ip.To4() 259 } 260 case dnsmessage.TypeAAAA: 261 for i, ip := range ips { 262 ips[i] = ip.To16() 263 } 264 } 265 266 b := buf.New() 267 rawBytes := b.Extend(buf.Size) 268 builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ 269 ID: id, 270 RCode: dnsmessage.RCode(rcode), 271 RecursionAvailable: true, 272 RecursionDesired: true, 273 Response: true, 274 Authoritative: true, 275 }) 276 builder.EnableCompression() 277 common.Must(builder.StartQuestions()) 278 common.Must(builder.Question(dnsmessage.Question{ 279 Name: dnsmessage.MustNewName(domain), 280 Class: dnsmessage.ClassINET, 281 Type: qType, 282 })) 283 common.Must(builder.StartAnswers()) 284 285 rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl} 286 for _, ip := range ips { 287 if len(ip) == net.IPv4len { 288 var r dnsmessage.AResource 289 copy(r.A[:], ip) 290 common.Must(builder.AResource(rHeader, r)) 291 } else { 292 var r dnsmessage.AAAAResource 293 copy(r.AAAA[:], ip) 294 common.Must(builder.AAAAResource(rHeader, r)) 295 } 296 } 297 msgBytes, err := builder.Finish() 298 if err != nil { 299 newError("pack message").Base(err).WriteToLog() 300 b.Release() 301 return 302 } 303 b.Resize(0, int32(len(msgBytes))) 304 305 if err := writer.WriteMessage(b); err != nil { 306 newError("write IP answer").Base(err).WriteToLog() 307 } 308 } 309 310 type outboundConn struct { 311 access sync.Mutex 312 dialer func() (stat.Connection, error) 313 314 conn net.Conn 315 connReady chan struct{} 316 } 317 318 func (c *outboundConn) dial() error { 319 conn, err := c.dialer() 320 if err != nil { 321 return err 322 } 323 c.conn = conn 324 c.connReady <- struct{}{} 325 return nil 326 } 327 328 func (c *outboundConn) Write(b []byte) (int, error) { 329 c.access.Lock() 330 331 if c.conn == nil { 332 if err := c.dial(); err != nil { 333 c.access.Unlock() 334 newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog() 335 return len(b), nil 336 } 337 } 338 339 c.access.Unlock() 340 341 return c.conn.Write(b) 342 } 343 344 func (c *outboundConn) Read(b []byte) (int, error) { 345 var conn net.Conn 346 c.access.Lock() 347 conn = c.conn 348 c.access.Unlock() 349 350 if conn == nil { 351 _, open := <-c.connReady 352 if !open { 353 return 0, io.EOF 354 } 355 conn = c.conn 356 } 357 358 return conn.Read(b) 359 } 360 361 func (c *outboundConn) Close() error { 362 c.access.Lock() 363 close(c.connReady) 364 if c.conn != nil { 365 c.conn.Close() 366 } 367 c.access.Unlock() 368 return nil 369 }