github.com/moqsien/xraycore@v1.8.5/proxy/dns/dns.go (about) 1 package dns 2 3 import ( 4 "context" 5 "io" 6 "sync" 7 "time" 8 9 "github.com/moqsien/xraycore/common" 10 "github.com/moqsien/xraycore/common/buf" 11 "github.com/moqsien/xraycore/common/errors" 12 "github.com/moqsien/xraycore/common/net" 13 dns_proto "github.com/moqsien/xraycore/common/protocol/dns" 14 "github.com/moqsien/xraycore/common/session" 15 "github.com/moqsien/xraycore/common/signal" 16 "github.com/moqsien/xraycore/common/task" 17 "github.com/moqsien/xraycore/core" 18 "github.com/moqsien/xraycore/features/dns" 19 "github.com/moqsien/xraycore/features/policy" 20 "github.com/moqsien/xraycore/transport" 21 "github.com/moqsien/xraycore/transport/internet" 22 "github.com/moqsien/xraycore/transport/internet/stat" 23 "golang.org/x/net/dns/dnsmessage" 24 ) 25 26 func init() { 27 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 28 h := new(Handler) 29 if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { 30 return h.Init(config.(*Config), dnsClient, policyManager) 31 }); err != nil { 32 return nil, err 33 } 34 return h, nil 35 })) 36 } 37 38 type ownLinkVerifier interface { 39 IsOwnLink(ctx context.Context) bool 40 } 41 42 type Handler struct { 43 client dns.Client 44 ownLinkVerifier ownLinkVerifier 45 server net.Destination 46 timeout time.Duration 47 nonIPQuery string 48 } 49 50 func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error { 51 h.client = dnsClient 52 h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle 53 54 if v, ok := dnsClient.(ownLinkVerifier); ok { 55 h.ownLinkVerifier = v 56 } 57 58 if config.Server != nil { 59 h.server = config.Server.AsDestination() 60 } 61 h.nonIPQuery = config.Non_IPQuery 62 return nil 63 } 64 65 func (h *Handler) isOwnLink(ctx context.Context) bool { 66 return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) 67 } 68 69 func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { 70 var parser dnsmessage.Parser 71 header, err := parser.Start(b) 72 if err != nil { 73 newError("parser start").Base(err).WriteToLog() 74 return 75 } 76 77 id = header.ID 78 q, err := parser.Question() 79 if err != nil { 80 newError("question").Base(err).WriteToLog() 81 return 82 } 83 qType = q.Type 84 if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { 85 return 86 } 87 88 domain = q.Name.String() 89 r = true 90 return 91 } 92 93 // Process implements proxy.Outbound. 94 func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { 95 outbound := session.OutboundFromContext(ctx) 96 if outbound == nil || !outbound.Target.IsValid() { 97 return newError("invalid outbound") 98 } 99 100 srcNetwork := outbound.Target.Network 101 102 dest := outbound.Target 103 if h.server.Network != net.Network_Unknown { 104 dest.Network = h.server.Network 105 } 106 if h.server.Address != nil { 107 dest.Address = h.server.Address 108 } 109 if h.server.Port != 0 { 110 dest.Port = h.server.Port 111 } 112 113 newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx)) 114 115 conn := &outboundConn{ 116 dialer: func() (stat.Connection, error) { 117 return d.Dial(ctx, dest) 118 }, 119 connReady: make(chan struct{}, 1), 120 } 121 122 var reader dns_proto.MessageReader 123 var writer dns_proto.MessageWriter 124 if srcNetwork == net.Network_TCP { 125 reader = dns_proto.NewTCPReader(link.Reader) 126 writer = &dns_proto.TCPWriter{ 127 Writer: link.Writer, 128 } 129 } else { 130 reader = &dns_proto.UDPReader{ 131 Reader: link.Reader, 132 } 133 writer = &dns_proto.UDPWriter{ 134 Writer: link.Writer, 135 } 136 } 137 138 var connReader dns_proto.MessageReader 139 var connWriter dns_proto.MessageWriter 140 if dest.Network == net.Network_TCP { 141 connReader = dns_proto.NewTCPReader(buf.NewReader(conn)) 142 connWriter = &dns_proto.TCPWriter{ 143 Writer: buf.NewWriter(conn), 144 } 145 } else { 146 connReader = &dns_proto.UDPReader{ 147 Reader: buf.NewPacketReader(conn), 148 } 149 connWriter = &dns_proto.UDPWriter{ 150 Writer: buf.NewWriter(conn), 151 } 152 } 153 154 if session.TimeoutOnlyFromContext(ctx) { 155 ctx, _ = context.WithCancel(context.Background()) 156 } 157 158 ctx, cancel := context.WithCancel(ctx) 159 timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) 160 161 request := func() error { 162 defer conn.Close() 163 164 for { 165 b, err := reader.ReadMessage() 166 if err == io.EOF { 167 return nil 168 } 169 170 if err != nil { 171 return err 172 } 173 174 timer.Update() 175 176 if !h.isOwnLink(ctx) { 177 isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) 178 if isIPQuery { 179 go h.handleIPQuery(id, qType, domain, writer) 180 } 181 if isIPQuery || h.nonIPQuery == "drop" { 182 b.Release() 183 continue 184 } 185 } 186 187 if err := connWriter.WriteMessage(b); err != nil { 188 return err 189 } 190 } 191 } 192 193 response := func() error { 194 for { 195 b, err := connReader.ReadMessage() 196 if err == io.EOF { 197 return nil 198 } 199 200 if err != nil { 201 return err 202 } 203 204 timer.Update() 205 206 if err := writer.WriteMessage(b); err != nil { 207 return err 208 } 209 } 210 } 211 212 if err := task.Run(ctx, request, response); err != nil { 213 return newError("connection ends").Base(err) 214 } 215 216 return nil 217 } 218 219 func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { 220 var ips []net.IP 221 var err error 222 223 var ttl uint32 = 600 224 225 switch qType { 226 case dnsmessage.TypeA: 227 ips, err = h.client.LookupIP(domain, dns.IPOption{ 228 IPv4Enable: true, 229 IPv6Enable: false, 230 FakeEnable: true, 231 }) 232 case dnsmessage.TypeAAAA: 233 ips, err = h.client.LookupIP(domain, dns.IPOption{ 234 IPv4Enable: false, 235 IPv6Enable: true, 236 FakeEnable: true, 237 }) 238 } 239 240 rcode := dns.RCodeFromError(err) 241 if rcode == 0 && len(ips) == 0 && !errors.AllEqual(dns.ErrEmptyResponse, errors.Cause(err)) { 242 newError("ip query").Base(err).WriteToLog() 243 return 244 } 245 246 switch qType { 247 case dnsmessage.TypeA: 248 for i, ip := range ips { 249 ips[i] = ip.To4() 250 } 251 case dnsmessage.TypeAAAA: 252 for i, ip := range ips { 253 ips[i] = ip.To16() 254 } 255 } 256 257 b := buf.New() 258 rawBytes := b.Extend(buf.Size) 259 builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ 260 ID: id, 261 RCode: dnsmessage.RCode(rcode), 262 RecursionAvailable: true, 263 RecursionDesired: true, 264 Response: true, 265 Authoritative: true, 266 }) 267 builder.EnableCompression() 268 common.Must(builder.StartQuestions()) 269 common.Must(builder.Question(dnsmessage.Question{ 270 Name: dnsmessage.MustNewName(domain), 271 Class: dnsmessage.ClassINET, 272 Type: qType, 273 })) 274 common.Must(builder.StartAnswers()) 275 276 rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl} 277 for _, ip := range ips { 278 if len(ip) == net.IPv4len { 279 var r dnsmessage.AResource 280 copy(r.A[:], ip) 281 common.Must(builder.AResource(rHeader, r)) 282 } else { 283 var r dnsmessage.AAAAResource 284 copy(r.AAAA[:], ip) 285 common.Must(builder.AAAAResource(rHeader, r)) 286 } 287 } 288 msgBytes, err := builder.Finish() 289 if err != nil { 290 newError("pack message").Base(err).WriteToLog() 291 b.Release() 292 return 293 } 294 b.Resize(0, int32(len(msgBytes))) 295 296 if err := writer.WriteMessage(b); err != nil { 297 newError("write IP answer").Base(err).WriteToLog() 298 } 299 } 300 301 type outboundConn struct { 302 access sync.Mutex 303 dialer func() (stat.Connection, error) 304 305 conn net.Conn 306 connReady chan struct{} 307 } 308 309 func (c *outboundConn) dial() error { 310 conn, err := c.dialer() 311 if err != nil { 312 return err 313 } 314 c.conn = conn 315 c.connReady <- struct{}{} 316 return nil 317 } 318 319 func (c *outboundConn) Write(b []byte) (int, error) { 320 c.access.Lock() 321 322 if c.conn == nil { 323 if err := c.dial(); err != nil { 324 c.access.Unlock() 325 newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog() 326 return len(b), nil 327 } 328 } 329 330 c.access.Unlock() 331 332 return c.conn.Write(b) 333 } 334 335 func (c *outboundConn) Read(b []byte) (int, error) { 336 var conn net.Conn 337 c.access.Lock() 338 conn = c.conn 339 c.access.Unlock() 340 341 if conn == nil { 342 _, open := <-c.connReady 343 if !open { 344 return 0, io.EOF 345 } 346 conn = c.conn 347 } 348 349 return conn.Read(b) 350 } 351 352 func (c *outboundConn) Close() error { 353 c.access.Lock() 354 close(c.connReady) 355 if c.conn != nil { 356 c.conn.Close() 357 } 358 c.access.Unlock() 359 return nil 360 }