github.com/xraypb/Xray-core@v1.8.1/proxy/dns/dns.go (about) 1 package dns 2 3 import ( 4 "context" 5 "io" 6 "sync" 7 "time" 8 9 "github.com/xraypb/Xray-core/common" 10 "github.com/xraypb/Xray-core/common/buf" 11 "github.com/xraypb/Xray-core/common/net" 12 dns_proto "github.com/xraypb/Xray-core/common/protocol/dns" 13 "github.com/xraypb/Xray-core/common/session" 14 "github.com/xraypb/Xray-core/common/signal" 15 "github.com/xraypb/Xray-core/common/task" 16 "github.com/xraypb/Xray-core/core" 17 "github.com/xraypb/Xray-core/features/dns" 18 "github.com/xraypb/Xray-core/features/policy" 19 "github.com/xraypb/Xray-core/transport" 20 "github.com/xraypb/Xray-core/transport/internet" 21 "github.com/xraypb/Xray-core/transport/internet/stat" 22 "golang.org/x/net/dns/dnsmessage" 23 ) 24 25 func init() { 26 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 27 h := new(Handler) 28 if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { 29 return h.Init(config.(*Config), dnsClient, policyManager) 30 }); err != nil { 31 return nil, err 32 } 33 return h, nil 34 })) 35 } 36 37 type ownLinkVerifier interface { 38 IsOwnLink(ctx context.Context) bool 39 } 40 41 type Handler struct { 42 client dns.Client 43 ownLinkVerifier ownLinkVerifier 44 server net.Destination 45 timeout time.Duration 46 } 47 48 func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error { 49 h.client = dnsClient 50 h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle 51 52 if v, ok := dnsClient.(ownLinkVerifier); ok { 53 h.ownLinkVerifier = v 54 } 55 56 if config.Server != nil { 57 h.server = config.Server.AsDestination() 58 } 59 return nil 60 } 61 62 func (h *Handler) isOwnLink(ctx context.Context) bool { 63 return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) 64 } 65 66 func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { 67 var parser dnsmessage.Parser 68 header, err := parser.Start(b) 69 if err != nil { 70 newError("parser start").Base(err).WriteToLog() 71 return 72 } 73 74 id = header.ID 75 q, err := parser.Question() 76 if err != nil { 77 newError("question").Base(err).WriteToLog() 78 return 79 } 80 qType = q.Type 81 if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { 82 return 83 } 84 85 domain = q.Name.String() 86 r = true 87 return 88 } 89 90 // Process implements proxy.Outbound. 91 func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { 92 outbound := session.OutboundFromContext(ctx) 93 if outbound == nil || !outbound.Target.IsValid() { 94 return newError("invalid outbound") 95 } 96 97 srcNetwork := outbound.Target.Network 98 99 dest := outbound.Target 100 if h.server.Network != net.Network_Unknown { 101 dest.Network = h.server.Network 102 } 103 if h.server.Address != nil { 104 dest.Address = h.server.Address 105 } 106 if h.server.Port != 0 { 107 dest.Port = h.server.Port 108 } 109 110 newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx)) 111 112 conn := &outboundConn{ 113 dialer: func() (stat.Connection, error) { 114 return d.Dial(ctx, dest) 115 }, 116 connReady: make(chan struct{}, 1), 117 } 118 119 var reader dns_proto.MessageReader 120 var writer dns_proto.MessageWriter 121 if srcNetwork == net.Network_TCP { 122 reader = dns_proto.NewTCPReader(link.Reader) 123 writer = &dns_proto.TCPWriter{ 124 Writer: link.Writer, 125 } 126 } else { 127 reader = &dns_proto.UDPReader{ 128 Reader: link.Reader, 129 } 130 writer = &dns_proto.UDPWriter{ 131 Writer: link.Writer, 132 } 133 } 134 135 var connReader dns_proto.MessageReader 136 var connWriter dns_proto.MessageWriter 137 if dest.Network == net.Network_TCP { 138 connReader = dns_proto.NewTCPReader(buf.NewReader(conn)) 139 connWriter = &dns_proto.TCPWriter{ 140 Writer: buf.NewWriter(conn), 141 } 142 } else { 143 connReader = &dns_proto.UDPReader{ 144 Reader: buf.NewPacketReader(conn), 145 } 146 connWriter = &dns_proto.UDPWriter{ 147 Writer: buf.NewWriter(conn), 148 } 149 } 150 151 if session.TimeoutOnlyFromContext(ctx) { 152 ctx, _ = context.WithCancel(context.Background()) 153 } 154 155 ctx, cancel := context.WithCancel(ctx) 156 timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) 157 158 request := func() error { 159 defer conn.Close() 160 161 for { 162 b, err := reader.ReadMessage() 163 if err == io.EOF { 164 return nil 165 } 166 167 if err != nil { 168 return err 169 } 170 171 timer.Update() 172 173 if !h.isOwnLink(ctx) { 174 isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) 175 if isIPQuery { 176 go h.handleIPQuery(id, qType, domain, writer) 177 continue 178 } 179 } 180 181 if err := connWriter.WriteMessage(b); err != nil { 182 return err 183 } 184 } 185 } 186 187 response := func() error { 188 for { 189 b, err := connReader.ReadMessage() 190 if err == io.EOF { 191 return nil 192 } 193 194 if err != nil { 195 return err 196 } 197 198 timer.Update() 199 200 if err := writer.WriteMessage(b); err != nil { 201 return err 202 } 203 } 204 } 205 206 if err := task.Run(ctx, request, response); err != nil { 207 return newError("connection ends").Base(err) 208 } 209 210 return nil 211 } 212 213 func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { 214 var ips []net.IP 215 var err error 216 217 var ttl uint32 = 600 218 219 switch qType { 220 case dnsmessage.TypeA: 221 ips, err = h.client.LookupIP(domain, dns.IPOption{ 222 IPv4Enable: true, 223 IPv6Enable: false, 224 FakeEnable: true, 225 }) 226 case dnsmessage.TypeAAAA: 227 ips, err = h.client.LookupIP(domain, dns.IPOption{ 228 IPv4Enable: false, 229 IPv6Enable: true, 230 FakeEnable: true, 231 }) 232 } 233 234 rcode := dns.RCodeFromError(err) 235 if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse { 236 newError("ip query").Base(err).WriteToLog() 237 return 238 } 239 240 switch qType { 241 case dnsmessage.TypeA: 242 for i, ip := range ips { 243 ips[i] = ip.To4() 244 } 245 case dnsmessage.TypeAAAA: 246 for i, ip := range ips { 247 ips[i] = ip.To16() 248 } 249 } 250 251 b := buf.New() 252 rawBytes := b.Extend(buf.Size) 253 builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ 254 ID: id, 255 RCode: dnsmessage.RCode(rcode), 256 RecursionAvailable: true, 257 RecursionDesired: true, 258 Response: true, 259 Authoritative: true, 260 }) 261 builder.EnableCompression() 262 common.Must(builder.StartQuestions()) 263 common.Must(builder.Question(dnsmessage.Question{ 264 Name: dnsmessage.MustNewName(domain), 265 Class: dnsmessage.ClassINET, 266 Type: qType, 267 })) 268 common.Must(builder.StartAnswers()) 269 270 rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl} 271 for _, ip := range ips { 272 if len(ip) == net.IPv4len { 273 var r dnsmessage.AResource 274 copy(r.A[:], ip) 275 common.Must(builder.AResource(rHeader, r)) 276 } else { 277 var r dnsmessage.AAAAResource 278 copy(r.AAAA[:], ip) 279 common.Must(builder.AAAAResource(rHeader, r)) 280 } 281 } 282 msgBytes, err := builder.Finish() 283 if err != nil { 284 newError("pack message").Base(err).WriteToLog() 285 b.Release() 286 return 287 } 288 b.Resize(0, int32(len(msgBytes))) 289 290 if err := writer.WriteMessage(b); err != nil { 291 newError("write IP answer").Base(err).WriteToLog() 292 } 293 } 294 295 type outboundConn struct { 296 access sync.Mutex 297 dialer func() (stat.Connection, error) 298 299 conn net.Conn 300 connReady chan struct{} 301 } 302 303 func (c *outboundConn) dial() error { 304 conn, err := c.dialer() 305 if err != nil { 306 return err 307 } 308 c.conn = conn 309 c.connReady <- struct{}{} 310 return nil 311 } 312 313 func (c *outboundConn) Write(b []byte) (int, error) { 314 c.access.Lock() 315 316 if c.conn == nil { 317 if err := c.dial(); err != nil { 318 c.access.Unlock() 319 newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog() 320 return len(b), nil 321 } 322 } 323 324 c.access.Unlock() 325 326 return c.conn.Write(b) 327 } 328 329 func (c *outboundConn) Read(b []byte) (int, error) { 330 var conn net.Conn 331 c.access.Lock() 332 conn = c.conn 333 c.access.Unlock() 334 335 if conn == nil { 336 _, open := <-c.connReady 337 if !open { 338 return 0, io.EOF 339 } 340 conn = c.conn 341 } 342 343 return conn.Read(b) 344 } 345 346 func (c *outboundConn) Close() error { 347 c.access.Lock() 348 close(c.connReady) 349 if c.conn != nil { 350 c.conn.Close() 351 } 352 c.access.Unlock() 353 return nil 354 }