github.com/xraypb/xray-core@v1.6.6/proxy/dns/dns.go (about) 1 package dns 2 3 import ( 4 "context" 5 "io" 6 "sync" 7 "time" 8 9 "github.com/xraypb/xray-core/common" 10 "github.com/xraypb/xray-core/common/buf" 11 "github.com/xraypb/xray-core/common/net" 12 dns_proto "github.com/xraypb/xray-core/common/protocol/dns" 13 "github.com/xraypb/xray-core/common/session" 14 "github.com/xraypb/xray-core/common/signal" 15 "github.com/xraypb/xray-core/common/task" 16 "github.com/xraypb/xray-core/core" 17 "github.com/xraypb/xray-core/features/dns" 18 "github.com/xraypb/xray-core/features/policy" 19 "github.com/xraypb/xray-core/transport" 20 "github.com/xraypb/xray-core/transport/internet" 21 "github.com/xraypb/xray-core/transport/internet/stat" 22 "golang.org/x/net/dns/dnsmessage" 23 ) 24 25 func init() { 26 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 27 h := new(Handler) 28 if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { 29 return h.Init(config.(*Config), dnsClient, policyManager) 30 }); err != nil { 31 return nil, err 32 } 33 return h, nil 34 })) 35 } 36 37 type ownLinkVerifier interface { 38 IsOwnLink(ctx context.Context) bool 39 } 40 41 type Handler struct { 42 client dns.Client 43 ownLinkVerifier ownLinkVerifier 44 server net.Destination 45 timeout time.Duration 46 } 47 48 func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error { 49 h.client = dnsClient 50 h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle 51 52 if v, ok := dnsClient.(ownLinkVerifier); ok { 53 h.ownLinkVerifier = v 54 } 55 56 if config.Server != nil { 57 h.server = config.Server.AsDestination() 58 } 59 return nil 60 } 61 62 func (h *Handler) isOwnLink(ctx context.Context) bool { 63 return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) 64 } 65 66 func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { 67 var parser dnsmessage.Parser 68 header, err := parser.Start(b) 69 if err != nil { 70 newError("parser start").Base(err).WriteToLog() 71 return 72 } 73 74 id = header.ID 75 q, err := parser.Question() 76 if err != nil { 77 newError("question").Base(err).WriteToLog() 78 return 79 } 80 qType = q.Type 81 if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { 82 return 83 } 84 85 domain = q.Name.String() 86 r = true 87 return 88 } 89 90 // Process implements proxy.Outbound. 91 func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { 92 outbound := session.OutboundFromContext(ctx) 93 if outbound == nil || !outbound.Target.IsValid() { 94 return newError("invalid outbound") 95 } 96 97 srcNetwork := outbound.Target.Network 98 99 dest := outbound.Target 100 if h.server.Network != net.Network_Unknown { 101 dest.Network = h.server.Network 102 } 103 if h.server.Address != nil { 104 dest.Address = h.server.Address 105 } 106 if h.server.Port != 0 { 107 dest.Port = h.server.Port 108 } 109 110 newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx)) 111 112 conn := &outboundConn{ 113 dialer: func() (stat.Connection, error) { 114 return d.Dial(ctx, dest) 115 }, 116 connReady: make(chan struct{}, 1), 117 } 118 119 var reader dns_proto.MessageReader 120 var writer dns_proto.MessageWriter 121 if srcNetwork == net.Network_TCP { 122 reader = dns_proto.NewTCPReader(link.Reader) 123 writer = &dns_proto.TCPWriter{ 124 Writer: link.Writer, 125 } 126 } else { 127 reader = &dns_proto.UDPReader{ 128 Reader: link.Reader, 129 } 130 writer = &dns_proto.UDPWriter{ 131 Writer: link.Writer, 132 } 133 } 134 135 var connReader dns_proto.MessageReader 136 var connWriter dns_proto.MessageWriter 137 if dest.Network == net.Network_TCP { 138 connReader = dns_proto.NewTCPReader(buf.NewReader(conn)) 139 connWriter = &dns_proto.TCPWriter{ 140 Writer: buf.NewWriter(conn), 141 } 142 } else { 143 connReader = &dns_proto.UDPReader{ 144 Reader: buf.NewPacketReader(conn), 145 } 146 connWriter = &dns_proto.UDPWriter{ 147 Writer: buf.NewWriter(conn), 148 } 149 } 150 151 ctx, cancel := context.WithCancel(ctx) 152 timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) 153 154 request := func() error { 155 defer conn.Close() 156 157 for { 158 b, err := reader.ReadMessage() 159 if err == io.EOF { 160 return nil 161 } 162 163 if err != nil { 164 return err 165 } 166 167 timer.Update() 168 169 if !h.isOwnLink(ctx) { 170 isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) 171 if isIPQuery { 172 go h.handleIPQuery(id, qType, domain, writer) 173 continue 174 } 175 } 176 177 if err := connWriter.WriteMessage(b); err != nil { 178 return err 179 } 180 } 181 } 182 183 response := func() error { 184 for { 185 b, err := connReader.ReadMessage() 186 if err == io.EOF { 187 return nil 188 } 189 190 if err != nil { 191 return err 192 } 193 194 timer.Update() 195 196 if err := writer.WriteMessage(b); err != nil { 197 return err 198 } 199 } 200 } 201 202 if err := task.Run(ctx, request, response); err != nil { 203 return newError("connection ends").Base(err) 204 } 205 206 return nil 207 } 208 209 func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { 210 var ips []net.IP 211 var err error 212 213 var ttl uint32 = 600 214 215 switch qType { 216 case dnsmessage.TypeA: 217 ips, err = h.client.LookupIP(domain, dns.IPOption{ 218 IPv4Enable: true, 219 IPv6Enable: false, 220 FakeEnable: true, 221 }) 222 case dnsmessage.TypeAAAA: 223 ips, err = h.client.LookupIP(domain, dns.IPOption{ 224 IPv4Enable: false, 225 IPv6Enable: true, 226 FakeEnable: true, 227 }) 228 } 229 230 rcode := dns.RCodeFromError(err) 231 if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse { 232 newError("ip query").Base(err).WriteToLog() 233 return 234 } 235 236 switch qType { 237 case dnsmessage.TypeA: 238 for i, ip := range ips { 239 ips[i] = ip.To4() 240 } 241 case dnsmessage.TypeAAAA: 242 for i, ip := range ips { 243 ips[i] = ip.To16() 244 } 245 } 246 247 b := buf.New() 248 rawBytes := b.Extend(buf.Size) 249 builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ 250 ID: id, 251 RCode: dnsmessage.RCode(rcode), 252 RecursionAvailable: true, 253 RecursionDesired: true, 254 Response: true, 255 Authoritative: true, 256 }) 257 builder.EnableCompression() 258 common.Must(builder.StartQuestions()) 259 common.Must(builder.Question(dnsmessage.Question{ 260 Name: dnsmessage.MustNewName(domain), 261 Class: dnsmessage.ClassINET, 262 Type: qType, 263 })) 264 common.Must(builder.StartAnswers()) 265 266 rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl} 267 for _, ip := range ips { 268 if len(ip) == net.IPv4len { 269 var r dnsmessage.AResource 270 copy(r.A[:], ip) 271 common.Must(builder.AResource(rHeader, r)) 272 } else { 273 var r dnsmessage.AAAAResource 274 copy(r.AAAA[:], ip) 275 common.Must(builder.AAAAResource(rHeader, r)) 276 } 277 } 278 msgBytes, err := builder.Finish() 279 if err != nil { 280 newError("pack message").Base(err).WriteToLog() 281 b.Release() 282 return 283 } 284 b.Resize(0, int32(len(msgBytes))) 285 286 if err := writer.WriteMessage(b); err != nil { 287 newError("write IP answer").Base(err).WriteToLog() 288 } 289 } 290 291 type outboundConn struct { 292 access sync.Mutex 293 dialer func() (stat.Connection, error) 294 295 conn net.Conn 296 connReady chan struct{} 297 } 298 299 func (c *outboundConn) dial() error { 300 conn, err := c.dialer() 301 if err != nil { 302 return err 303 } 304 c.conn = conn 305 c.connReady <- struct{}{} 306 return nil 307 } 308 309 func (c *outboundConn) Write(b []byte) (int, error) { 310 c.access.Lock() 311 312 if c.conn == nil { 313 if err := c.dial(); err != nil { 314 c.access.Unlock() 315 newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog() 316 return len(b), nil 317 } 318 } 319 320 c.access.Unlock() 321 322 return c.conn.Write(b) 323 } 324 325 func (c *outboundConn) Read(b []byte) (int, error) { 326 var conn net.Conn 327 c.access.Lock() 328 conn = c.conn 329 c.access.Unlock() 330 331 if conn == nil { 332 _, open := <-c.connReady 333 if !open { 334 return 0, io.EOF 335 } 336 conn = c.conn 337 } 338 339 return conn.Read(b) 340 } 341 342 func (c *outboundConn) Close() error { 343 c.access.Lock() 344 close(c.connReady) 345 if c.conn != nil { 346 c.conn.Close() 347 } 348 c.access.Unlock() 349 return nil 350 }