github.com/EagleQL/Xray-core@v1.4.3/proxy/dns/dns.go (about) 1 package dns 2 3 import ( 4 "context" 5 "io" 6 "sync" 7 8 "golang.org/x/net/dns/dnsmessage" 9 10 "github.com/xtls/xray-core/common" 11 "github.com/xtls/xray-core/common/buf" 12 "github.com/xtls/xray-core/common/net" 13 dns_proto "github.com/xtls/xray-core/common/protocol/dns" 14 "github.com/xtls/xray-core/common/session" 15 "github.com/xtls/xray-core/common/task" 16 "github.com/xtls/xray-core/core" 17 "github.com/xtls/xray-core/features/dns" 18 "github.com/xtls/xray-core/transport" 19 "github.com/xtls/xray-core/transport/internet" 20 ) 21 22 func init() { 23 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 24 h := new(Handler) 25 if err := core.RequireFeatures(ctx, func(dnsClient dns.Client) error { 26 return h.Init(config.(*Config), dnsClient) 27 }); err != nil { 28 return nil, err 29 } 30 return h, nil 31 })) 32 } 33 34 type ownLinkVerifier interface { 35 IsOwnLink(ctx context.Context) bool 36 } 37 38 type Handler struct { 39 client dns.Client 40 ownLinkVerifier ownLinkVerifier 41 server net.Destination 42 } 43 44 func (h *Handler) Init(config *Config, dnsClient dns.Client) error { 45 h.client = dnsClient 46 if v, ok := dnsClient.(ownLinkVerifier); ok { 47 h.ownLinkVerifier = v 48 } 49 50 if config.Server != nil { 51 h.server = config.Server.AsDestination() 52 } 53 return nil 54 } 55 56 func (h *Handler) isOwnLink(ctx context.Context) bool { 57 return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) 58 } 59 60 func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { 61 var parser dnsmessage.Parser 62 header, err := parser.Start(b) 63 if err != nil { 64 newError("parser start").Base(err).WriteToLog() 65 return 66 } 67 68 id = header.ID 69 q, err := parser.Question() 70 if err != nil { 71 newError("question").Base(err).WriteToLog() 72 return 73 } 74 qType = q.Type 75 if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { 76 return 77 } 78 79 domain = q.Name.String() 80 r = true 81 return 82 } 83 84 // Process implements proxy.Outbound. 85 func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { 86 outbound := session.OutboundFromContext(ctx) 87 if outbound == nil || !outbound.Target.IsValid() { 88 return newError("invalid outbound") 89 } 90 91 srcNetwork := outbound.Target.Network 92 93 dest := outbound.Target 94 if h.server.Network != net.Network_Unknown { 95 dest.Network = h.server.Network 96 } 97 if h.server.Address != nil { 98 dest.Address = h.server.Address 99 } 100 if h.server.Port != 0 { 101 dest.Port = h.server.Port 102 } 103 104 newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx)) 105 106 conn := &outboundConn{ 107 dialer: func() (internet.Connection, error) { 108 return d.Dial(ctx, dest) 109 }, 110 connReady: make(chan struct{}, 1), 111 } 112 113 var reader dns_proto.MessageReader 114 var writer dns_proto.MessageWriter 115 if srcNetwork == net.Network_TCP { 116 reader = dns_proto.NewTCPReader(link.Reader) 117 writer = &dns_proto.TCPWriter{ 118 Writer: link.Writer, 119 } 120 } else { 121 reader = &dns_proto.UDPReader{ 122 Reader: link.Reader, 123 } 124 writer = &dns_proto.UDPWriter{ 125 Writer: link.Writer, 126 } 127 } 128 129 var connReader dns_proto.MessageReader 130 var connWriter dns_proto.MessageWriter 131 if dest.Network == net.Network_TCP { 132 connReader = dns_proto.NewTCPReader(buf.NewReader(conn)) 133 connWriter = &dns_proto.TCPWriter{ 134 Writer: buf.NewWriter(conn), 135 } 136 } else { 137 connReader = &dns_proto.UDPReader{ 138 Reader: buf.NewPacketReader(conn), 139 } 140 connWriter = &dns_proto.UDPWriter{ 141 Writer: buf.NewWriter(conn), 142 } 143 } 144 145 request := func() error { 146 defer conn.Close() 147 148 for { 149 b, err := reader.ReadMessage() 150 if err == io.EOF { 151 return nil 152 } 153 154 if err != nil { 155 return err 156 } 157 158 if !h.isOwnLink(ctx) { 159 isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) 160 if isIPQuery { 161 go h.handleIPQuery(id, qType, domain, writer) 162 continue 163 } 164 } 165 166 if err := connWriter.WriteMessage(b); err != nil { 167 return err 168 } 169 } 170 } 171 172 response := func() error { 173 for { 174 b, err := connReader.ReadMessage() 175 if err == io.EOF { 176 return nil 177 } 178 179 if err != nil { 180 return err 181 } 182 183 if err := writer.WriteMessage(b); err != nil { 184 return err 185 } 186 } 187 } 188 189 if err := task.Run(ctx, request, response); err != nil { 190 return newError("connection ends").Base(err) 191 } 192 193 return nil 194 } 195 196 func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { 197 var ips []net.IP 198 var err error 199 200 var ttl uint32 = 600 201 202 switch qType { 203 case dnsmessage.TypeA: 204 ips, err = h.client.LookupIP(domain, dns.IPOption{ 205 IPv4Enable: true, 206 IPv6Enable: false, 207 FakeEnable: true, 208 }) 209 case dnsmessage.TypeAAAA: 210 ips, err = h.client.LookupIP(domain, dns.IPOption{ 211 IPv4Enable: false, 212 IPv6Enable: true, 213 FakeEnable: true, 214 }) 215 } 216 217 rcode := dns.RCodeFromError(err) 218 if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse { 219 newError("ip query").Base(err).WriteToLog() 220 return 221 } 222 223 b := buf.New() 224 rawBytes := b.Extend(buf.Size) 225 builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ 226 ID: id, 227 RCode: dnsmessage.RCode(rcode), 228 RecursionAvailable: true, 229 RecursionDesired: true, 230 Response: true, 231 Authoritative: true, 232 }) 233 builder.EnableCompression() 234 common.Must(builder.StartQuestions()) 235 common.Must(builder.Question(dnsmessage.Question{ 236 Name: dnsmessage.MustNewName(domain), 237 Class: dnsmessage.ClassINET, 238 Type: qType, 239 })) 240 common.Must(builder.StartAnswers()) 241 242 rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl} 243 for _, ip := range ips { 244 if len(ip) == net.IPv4len { 245 var r dnsmessage.AResource 246 copy(r.A[:], ip) 247 common.Must(builder.AResource(rHeader, r)) 248 } else { 249 var r dnsmessage.AAAAResource 250 copy(r.AAAA[:], ip) 251 common.Must(builder.AAAAResource(rHeader, r)) 252 } 253 } 254 msgBytes, err := builder.Finish() 255 if err != nil { 256 newError("pack message").Base(err).WriteToLog() 257 b.Release() 258 return 259 } 260 b.Resize(0, int32(len(msgBytes))) 261 262 if err := writer.WriteMessage(b); err != nil { 263 newError("write IP answer").Base(err).WriteToLog() 264 } 265 } 266 267 type outboundConn struct { 268 access sync.Mutex 269 dialer func() (internet.Connection, error) 270 271 conn net.Conn 272 connReady chan struct{} 273 } 274 275 func (c *outboundConn) dial() error { 276 conn, err := c.dialer() 277 if err != nil { 278 return err 279 } 280 c.conn = conn 281 c.connReady <- struct{}{} 282 return nil 283 } 284 285 func (c *outboundConn) Write(b []byte) (int, error) { 286 c.access.Lock() 287 288 if c.conn == nil { 289 if err := c.dial(); err != nil { 290 c.access.Unlock() 291 newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog() 292 return len(b), nil 293 } 294 } 295 296 c.access.Unlock() 297 298 return c.conn.Write(b) 299 } 300 301 func (c *outboundConn) Read(b []byte) (int, error) { 302 var conn net.Conn 303 c.access.Lock() 304 conn = c.conn 305 c.access.Unlock() 306 307 if conn == nil { 308 _, open := <-c.connReady 309 if !open { 310 return 0, io.EOF 311 } 312 conn = c.conn 313 } 314 315 return conn.Read(b) 316 } 317 318 func (c *outboundConn) Close() error { 319 c.access.Lock() 320 close(c.connReady) 321 if c.conn != nil { 322 c.conn.Close() 323 } 324 c.access.Unlock() 325 return nil 326 }