github.com/v2fly/v2ray-core/v4@v4.45.2/proxy/dns/dns.go (about) 1 //go:build !confonly 2 // +build !confonly 3 4 package dns 5 6 import ( 7 "context" 8 "io" 9 "sync" 10 "time" 11 12 "golang.org/x/net/dns/dnsmessage" 13 14 core "github.com/v2fly/v2ray-core/v4" 15 "github.com/v2fly/v2ray-core/v4/common" 16 "github.com/v2fly/v2ray-core/v4/common/buf" 17 "github.com/v2fly/v2ray-core/v4/common/net" 18 dns_proto "github.com/v2fly/v2ray-core/v4/common/protocol/dns" 19 "github.com/v2fly/v2ray-core/v4/common/session" 20 "github.com/v2fly/v2ray-core/v4/common/signal" 21 "github.com/v2fly/v2ray-core/v4/common/task" 22 "github.com/v2fly/v2ray-core/v4/features/dns" 23 "github.com/v2fly/v2ray-core/v4/features/policy" 24 "github.com/v2fly/v2ray-core/v4/transport" 25 "github.com/v2fly/v2ray-core/v4/transport/internet" 26 ) 27 28 func init() { 29 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 30 h := new(Handler) 31 if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { 32 return h.Init(config.(*Config), dnsClient, policyManager) 33 }); err != nil { 34 return nil, err 35 } 36 return h, nil 37 })) 38 } 39 40 type ownLinkVerifier interface { 41 IsOwnLink(ctx context.Context) bool 42 } 43 44 type Handler struct { 45 client dns.Client 46 ipv4Lookup dns.IPv4Lookup 47 ipv6Lookup dns.IPv6Lookup 48 ownLinkVerifier ownLinkVerifier 49 server net.Destination 50 timeout time.Duration 51 } 52 53 func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error { 54 h.client = dnsClient 55 h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle 56 57 if ipv4lookup, ok := dnsClient.(dns.IPv4Lookup); ok { 58 h.ipv4Lookup = ipv4lookup 59 } else { 60 return newError("dns.Client doesn't implement IPv4Lookup") 61 } 62 63 if ipv6lookup, ok := dnsClient.(dns.IPv6Lookup); ok { 64 h.ipv6Lookup = ipv6lookup 65 } else { 66 return newError("dns.Client doesn't implement IPv6Lookup") 67 } 68 69 if v, ok := dnsClient.(ownLinkVerifier); ok { 70 h.ownLinkVerifier = v 71 } 72 73 if config.Server != nil { 74 h.server = config.Server.AsDestination() 75 } 76 return nil 77 } 78 79 func (h *Handler) isOwnLink(ctx context.Context) bool { 80 return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) 81 } 82 83 func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { 84 var parser dnsmessage.Parser 85 header, err := parser.Start(b) 86 if err != nil { 87 newError("parser start").Base(err).WriteToLog() 88 return 89 } 90 91 id = header.ID 92 q, err := parser.Question() 93 if err != nil { 94 newError("question").Base(err).WriteToLog() 95 return 96 } 97 qType = q.Type 98 if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { 99 return 100 } 101 102 domain = q.Name.String() 103 r = true 104 return 105 } 106 107 // Process implements proxy.Outbound. 108 func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { 109 outbound := session.OutboundFromContext(ctx) 110 if outbound == nil || !outbound.Target.IsValid() { 111 return newError("invalid outbound") 112 } 113 114 srcNetwork := outbound.Target.Network 115 116 dest := outbound.Target 117 if h.server.Network != net.Network_Unknown { 118 dest.Network = h.server.Network 119 } 120 if h.server.Address != nil { 121 dest.Address = h.server.Address 122 } 123 if h.server.Port != 0 { 124 dest.Port = h.server.Port 125 } 126 127 newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx)) 128 129 conn := &outboundConn{ 130 dialer: func() (internet.Connection, error) { 131 return d.Dial(ctx, dest) 132 }, 133 connReady: make(chan struct{}, 1), 134 } 135 136 var reader dns_proto.MessageReader 137 var writer dns_proto.MessageWriter 138 if srcNetwork == net.Network_TCP { 139 reader = dns_proto.NewTCPReader(link.Reader) 140 writer = &dns_proto.TCPWriter{ 141 Writer: link.Writer, 142 } 143 } else { 144 reader = &dns_proto.UDPReader{ 145 Reader: link.Reader, 146 } 147 writer = &dns_proto.UDPWriter{ 148 Writer: link.Writer, 149 } 150 } 151 152 var connReader dns_proto.MessageReader 153 var connWriter dns_proto.MessageWriter 154 if dest.Network == net.Network_TCP { 155 connReader = dns_proto.NewTCPReader(buf.NewReader(conn)) 156 connWriter = &dns_proto.TCPWriter{ 157 Writer: buf.NewWriter(conn), 158 } 159 } else { 160 connReader = &dns_proto.UDPReader{ 161 Reader: buf.NewPacketReader(conn), 162 } 163 connWriter = &dns_proto.UDPWriter{ 164 Writer: buf.NewWriter(conn), 165 } 166 } 167 168 ctx, cancel := context.WithCancel(ctx) 169 timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) 170 171 request := func() error { 172 defer conn.Close() 173 174 for { 175 b, err := reader.ReadMessage() 176 if err == io.EOF { 177 return nil 178 } 179 180 if err != nil { 181 return err 182 } 183 184 timer.Update() 185 186 if !h.isOwnLink(ctx) { 187 isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) 188 if isIPQuery { 189 go h.handleIPQuery(id, qType, domain, writer) 190 continue 191 } 192 } 193 194 if err := connWriter.WriteMessage(b); err != nil { 195 return err 196 } 197 } 198 } 199 200 response := func() error { 201 for { 202 b, err := connReader.ReadMessage() 203 if err == io.EOF { 204 return nil 205 } 206 207 if err != nil { 208 return err 209 } 210 211 timer.Update() 212 213 if err := writer.WriteMessage(b); err != nil { 214 return err 215 } 216 } 217 } 218 219 if err := task.Run(ctx, request, response); err != nil { 220 return newError("connection ends").Base(err) 221 } 222 223 return nil 224 } 225 226 func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { 227 var ips []net.IP 228 var err error 229 230 var ttl uint32 = 600 231 232 // Do NOT skip FakeDNS 233 if c, ok := h.client.(dns.ClientWithIPOption); ok { 234 c.SetFakeDNSOption(true) 235 } else { 236 newError("dns.Client doesn't implement ClientWithIPOption") 237 } 238 239 switch qType { 240 case dnsmessage.TypeA: 241 ips, err = h.ipv4Lookup.LookupIPv4(domain) 242 case dnsmessage.TypeAAAA: 243 ips, err = h.ipv6Lookup.LookupIPv6(domain) 244 } 245 246 rcode := dns.RCodeFromError(err) 247 if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse { 248 newError("ip query").Base(err).WriteToLog() 249 return 250 } 251 252 b := buf.New() 253 rawBytes := b.Extend(buf.Size) 254 builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ 255 ID: id, 256 RCode: dnsmessage.RCode(rcode), 257 RecursionAvailable: true, 258 RecursionDesired: true, 259 Response: true, 260 }) 261 builder.EnableCompression() 262 common.Must(builder.StartQuestions()) 263 common.Must(builder.Question(dnsmessage.Question{ 264 Name: dnsmessage.MustNewName(domain), 265 Class: dnsmessage.ClassINET, 266 Type: qType, 267 })) 268 common.Must(builder.StartAnswers()) 269 270 rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl} 271 for _, ip := range ips { 272 if len(ip) == net.IPv4len { 273 var r dnsmessage.AResource 274 copy(r.A[:], ip) 275 common.Must(builder.AResource(rHeader, r)) 276 } else { 277 var r dnsmessage.AAAAResource 278 copy(r.AAAA[:], ip) 279 common.Must(builder.AAAAResource(rHeader, r)) 280 } 281 } 282 msgBytes, err := builder.Finish() 283 if err != nil { 284 newError("pack message").Base(err).WriteToLog() 285 b.Release() 286 return 287 } 288 b.Resize(0, int32(len(msgBytes))) 289 290 if err := writer.WriteMessage(b); err != nil { 291 newError("write IP answer").Base(err).WriteToLog() 292 } 293 } 294 295 type outboundConn struct { 296 access sync.Mutex 297 dialer func() (internet.Connection, error) 298 299 conn net.Conn 300 connReady chan struct{} 301 } 302 303 func (c *outboundConn) dial() error { 304 conn, err := c.dialer() 305 if err != nil { 306 return err 307 } 308 c.conn = conn 309 c.connReady <- struct{}{} 310 return nil 311 } 312 313 func (c *outboundConn) Write(b []byte) (int, error) { 314 c.access.Lock() 315 316 if c.conn == nil { 317 if err := c.dial(); err != nil { 318 c.access.Unlock() 319 newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog() 320 return len(b), nil 321 } 322 } 323 324 c.access.Unlock() 325 326 return c.conn.Write(b) 327 } 328 329 func (c *outboundConn) Read(b []byte) (int, error) { 330 var conn net.Conn 331 c.access.Lock() 332 conn = c.conn 333 c.access.Unlock() 334 335 if conn == nil { 336 _, open := <-c.connReady 337 if !open { 338 return 0, io.EOF 339 } 340 conn = c.conn 341 } 342 343 return conn.Read(b) 344 } 345 346 func (c *outboundConn) Close() error { 347 c.access.Lock() 348 close(c.connReady) 349 if c.conn != nil { 350 c.conn.Close() 351 } 352 c.access.Unlock() 353 return nil 354 }