github.com/Uhtred009/v2ray-core-1@v4.31.2+incompatible/proxy/dns/dns.go (about) 1 // +build !confonly 2 3 package dns 4 5 import ( 6 "context" 7 "io" 8 "sync" 9 10 "golang.org/x/net/dns/dnsmessage" 11 12 "v2ray.com/core" 13 "v2ray.com/core/common" 14 "v2ray.com/core/common/buf" 15 "v2ray.com/core/common/net" 16 dns_proto "v2ray.com/core/common/protocol/dns" 17 "v2ray.com/core/common/session" 18 "v2ray.com/core/common/task" 19 "v2ray.com/core/features/dns" 20 "v2ray.com/core/transport" 21 "v2ray.com/core/transport/internet" 22 ) 23 24 func init() { 25 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 26 h := new(Handler) 27 if err := core.RequireFeatures(ctx, func(dnsClient dns.Client) error { 28 return h.Init(config.(*Config), dnsClient) 29 }); err != nil { 30 return nil, err 31 } 32 return h, nil 33 })) 34 } 35 36 type ownLinkVerifier interface { 37 IsOwnLink(ctx context.Context) bool 38 } 39 40 type Handler struct { 41 ipv4Lookup dns.IPv4Lookup 42 ipv6Lookup dns.IPv6Lookup 43 ownLinkVerifier ownLinkVerifier 44 server net.Destination 45 } 46 47 func (h *Handler) Init(config *Config, dnsClient dns.Client) error { 48 ipv4lookup, ok := dnsClient.(dns.IPv4Lookup) 49 if !ok { 50 return newError("dns.Client doesn't implement IPv4Lookup") 51 } 52 h.ipv4Lookup = ipv4lookup 53 54 ipv6lookup, ok := dnsClient.(dns.IPv6Lookup) 55 if !ok { 56 return newError("dns.Client doesn't implement IPv6Lookup") 57 } 58 h.ipv6Lookup = ipv6lookup 59 60 if v, ok := dnsClient.(ownLinkVerifier); ok { 61 h.ownLinkVerifier = v 62 } 63 64 if config.Server != nil { 65 h.server = config.Server.AsDestination() 66 } 67 return nil 68 } 69 70 func (h *Handler) isOwnLink(ctx context.Context) bool { 71 return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) 72 } 73 74 func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { 75 var parser dnsmessage.Parser 76 header, err := parser.Start(b) 77 if err != nil { 78 newError("parser start").Base(err).WriteToLog() 79 return 80 } 81 82 id = header.ID 83 q, err := parser.Question() 84 if err != nil { 85 newError("question").Base(err).WriteToLog() 86 return 87 } 88 qType = q.Type 89 if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { 90 return 91 } 92 93 domain = q.Name.String() 94 r = true 95 return 96 } 97 98 // Process implements proxy.Outbound. 99 func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { 100 outbound := session.OutboundFromContext(ctx) 101 if outbound == nil || !outbound.Target.IsValid() { 102 return newError("invalid outbound") 103 } 104 105 srcNetwork := outbound.Target.Network 106 107 dest := outbound.Target 108 if h.server.Network != net.Network_Unknown { 109 dest.Network = h.server.Network 110 } 111 if h.server.Address != nil { 112 dest.Address = h.server.Address 113 } 114 if h.server.Port != 0 { 115 dest.Port = h.server.Port 116 } 117 118 newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx)) 119 120 conn := &outboundConn{ 121 dialer: func() (internet.Connection, error) { 122 return d.Dial(ctx, dest) 123 }, 124 connReady: make(chan struct{}, 1), 125 } 126 127 var reader dns_proto.MessageReader 128 var writer dns_proto.MessageWriter 129 if srcNetwork == net.Network_TCP { 130 reader = dns_proto.NewTCPReader(link.Reader) 131 writer = &dns_proto.TCPWriter{ 132 Writer: link.Writer, 133 } 134 } else { 135 reader = &dns_proto.UDPReader{ 136 Reader: link.Reader, 137 } 138 writer = &dns_proto.UDPWriter{ 139 Writer: link.Writer, 140 } 141 } 142 143 var connReader dns_proto.MessageReader 144 var connWriter dns_proto.MessageWriter 145 if dest.Network == net.Network_TCP { 146 connReader = dns_proto.NewTCPReader(buf.NewReader(conn)) 147 connWriter = &dns_proto.TCPWriter{ 148 Writer: buf.NewWriter(conn), 149 } 150 } else { 151 connReader = &dns_proto.UDPReader{ 152 Reader: buf.NewPacketReader(conn), 153 } 154 connWriter = &dns_proto.UDPWriter{ 155 Writer: buf.NewWriter(conn), 156 } 157 } 158 159 request := func() error { 160 defer conn.Close() 161 162 for { 163 b, err := reader.ReadMessage() 164 if err == io.EOF { 165 return nil 166 } 167 168 if err != nil { 169 return err 170 } 171 172 if !h.isOwnLink(ctx) { 173 isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) 174 if isIPQuery { 175 go h.handleIPQuery(id, qType, domain, writer) 176 continue 177 } 178 } 179 180 if err := connWriter.WriteMessage(b); err != nil { 181 return err 182 } 183 } 184 } 185 186 response := func() error { 187 for { 188 b, err := connReader.ReadMessage() 189 if err == io.EOF { 190 return nil 191 } 192 193 if err != nil { 194 return err 195 } 196 197 if err := writer.WriteMessage(b); err != nil { 198 return err 199 } 200 } 201 } 202 203 if err := task.Run(ctx, request, response); err != nil { 204 return newError("connection ends").Base(err) 205 } 206 207 return nil 208 } 209 210 func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { 211 var ips []net.IP 212 var err error 213 214 switch qType { 215 case dnsmessage.TypeA: 216 ips, err = h.ipv4Lookup.LookupIPv4(domain) 217 case dnsmessage.TypeAAAA: 218 ips, err = h.ipv6Lookup.LookupIPv6(domain) 219 } 220 221 rcode := dns.RCodeFromError(err) 222 if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse { 223 newError("ip query").Base(err).WriteToLog() 224 return 225 } 226 227 b := buf.New() 228 rawBytes := b.Extend(buf.Size) 229 builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ 230 ID: id, 231 RCode: dnsmessage.RCode(rcode), 232 RecursionAvailable: true, 233 RecursionDesired: true, 234 Response: true, 235 Authoritative: true, 236 }) 237 builder.EnableCompression() 238 common.Must(builder.StartQuestions()) 239 common.Must(builder.Question(dnsmessage.Question{ 240 Name: dnsmessage.MustNewName(domain), 241 Class: dnsmessage.ClassINET, 242 Type: qType, 243 })) 244 common.Must(builder.StartAnswers()) 245 246 rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: 600} 247 for _, ip := range ips { 248 if len(ip) == net.IPv4len { 249 var r dnsmessage.AResource 250 copy(r.A[:], ip) 251 common.Must(builder.AResource(rHeader, r)) 252 } else { 253 var r dnsmessage.AAAAResource 254 copy(r.AAAA[:], ip) 255 common.Must(builder.AAAAResource(rHeader, r)) 256 } 257 } 258 msgBytes, err := builder.Finish() 259 if err != nil { 260 newError("pack message").Base(err).WriteToLog() 261 b.Release() 262 return 263 } 264 b.Resize(0, int32(len(msgBytes))) 265 266 if err := writer.WriteMessage(b); err != nil { 267 newError("write IP answer").Base(err).WriteToLog() 268 } 269 } 270 271 type outboundConn struct { 272 access sync.Mutex 273 dialer func() (internet.Connection, error) 274 275 conn net.Conn 276 connReady chan struct{} 277 } 278 279 func (c *outboundConn) dial() error { 280 conn, err := c.dialer() 281 if err != nil { 282 return err 283 } 284 c.conn = conn 285 c.connReady <- struct{}{} 286 return nil 287 } 288 289 func (c *outboundConn) Write(b []byte) (int, error) { 290 c.access.Lock() 291 292 if c.conn == nil { 293 if err := c.dial(); err != nil { 294 c.access.Unlock() 295 newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog() 296 return len(b), nil 297 } 298 } 299 300 c.access.Unlock() 301 302 return c.conn.Write(b) 303 } 304 305 func (c *outboundConn) Read(b []byte) (int, error) { 306 var conn net.Conn 307 c.access.Lock() 308 conn = c.conn 309 c.access.Unlock() 310 311 if conn == nil { 312 _, open := <-c.connReady 313 if !open { 314 return 0, io.EOF 315 } 316 conn = c.conn 317 } 318 319 return conn.Read(b) 320 } 321 322 func (c *outboundConn) Close() error { 323 c.access.Lock() 324 close(c.connReady) 325 if c.conn != nil { 326 c.conn.Close() 327 } 328 c.access.Unlock() 329 return nil 330 }