github.com/imannamdari/v2ray-core/v5@v5.0.5/proxy/dns/dns.go (about) 1 package dns 2 3 import ( 4 "context" 5 "io" 6 "sync" 7 "time" 8 9 "golang.org/x/net/dns/dnsmessage" 10 11 core "github.com/imannamdari/v2ray-core/v5" 12 "github.com/imannamdari/v2ray-core/v5/common" 13 "github.com/imannamdari/v2ray-core/v5/common/buf" 14 "github.com/imannamdari/v2ray-core/v5/common/net" 15 dns_proto "github.com/imannamdari/v2ray-core/v5/common/protocol/dns" 16 "github.com/imannamdari/v2ray-core/v5/common/session" 17 "github.com/imannamdari/v2ray-core/v5/common/signal" 18 "github.com/imannamdari/v2ray-core/v5/common/strmatcher" 19 "github.com/imannamdari/v2ray-core/v5/common/task" 20 "github.com/imannamdari/v2ray-core/v5/features/dns" 21 "github.com/imannamdari/v2ray-core/v5/features/policy" 22 "github.com/imannamdari/v2ray-core/v5/transport" 23 "github.com/imannamdari/v2ray-core/v5/transport/internet" 24 ) 25 26 func init() { 27 common.Must(common.RegisterConfig((*Config)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 28 h := new(Handler) 29 if err := core.RequireFeatures(ctx, func(dnsClient dns.Client, policyManager policy.Manager) error { 30 return h.Init(config.(*Config), dnsClient, policyManager) 31 }); err != nil { 32 return nil, err 33 } 34 return h, nil 35 })) 36 37 common.Must(common.RegisterConfig((*SimplifiedConfig)(nil), func(ctx context.Context, config interface{}) (interface{}, error) { 38 simplifiedServer := config.(*SimplifiedConfig) 39 _ = simplifiedServer 40 fullConfig := &Config{} 41 return common.CreateObject(ctx, fullConfig) 42 })) 43 } 44 45 type ownLinkVerifier interface { 46 IsOwnLink(ctx context.Context) bool 47 } 48 49 type Handler struct { 50 client dns.Client 51 ipv4Lookup dns.IPv4Lookup 52 ipv6Lookup dns.IPv6Lookup 53 ownLinkVerifier ownLinkVerifier 54 server net.Destination 55 timeout time.Duration 56 } 57 58 func (h *Handler) Init(config *Config, dnsClient dns.Client, policyManager policy.Manager) error { 59 // Enable FakeDNS for DNS outbound 60 if clientWithFakeDNS, ok := dnsClient.(dns.ClientWithFakeDNS); ok { 61 dnsClient = clientWithFakeDNS.AsFakeDNSClient() 62 } 63 h.client = dnsClient 64 h.timeout = policyManager.ForLevel(config.UserLevel).Timeouts.ConnectionIdle 65 66 if ipv4lookup, ok := dnsClient.(dns.IPv4Lookup); ok { 67 h.ipv4Lookup = ipv4lookup 68 } else { 69 return newError("dns.Client doesn't implement IPv4Lookup") 70 } 71 72 if ipv6lookup, ok := dnsClient.(dns.IPv6Lookup); ok { 73 h.ipv6Lookup = ipv6lookup 74 } else { 75 return newError("dns.Client doesn't implement IPv6Lookup") 76 } 77 78 if v, ok := dnsClient.(ownLinkVerifier); ok { 79 h.ownLinkVerifier = v 80 } 81 82 if config.Server != nil { 83 h.server = config.Server.AsDestination() 84 } 85 return nil 86 } 87 88 func (h *Handler) isOwnLink(ctx context.Context) bool { 89 return h.ownLinkVerifier != nil && h.ownLinkVerifier.IsOwnLink(ctx) 90 } 91 92 func parseIPQuery(b []byte) (r bool, domain string, id uint16, qType dnsmessage.Type) { 93 var parser dnsmessage.Parser 94 header, err := parser.Start(b) 95 if err != nil { 96 newError("parser start").Base(err).WriteToLog() 97 return 98 } 99 100 id = header.ID 101 q, err := parser.Question() 102 if err != nil { 103 newError("question").Base(err).WriteToLog() 104 return 105 } 106 qType = q.Type 107 if qType != dnsmessage.TypeA && qType != dnsmessage.TypeAAAA { 108 return 109 } 110 111 domain = q.Name.String() 112 r = true 113 return 114 } 115 116 // Process implements proxy.Outbound. 117 func (h *Handler) Process(ctx context.Context, link *transport.Link, d internet.Dialer) error { 118 outbound := session.OutboundFromContext(ctx) 119 if outbound == nil || !outbound.Target.IsValid() { 120 return newError("invalid outbound") 121 } 122 123 srcNetwork := outbound.Target.Network 124 125 dest := outbound.Target 126 if h.server.Network != net.Network_Unknown { 127 dest.Network = h.server.Network 128 } 129 if h.server.Address != nil { 130 dest.Address = h.server.Address 131 } 132 if h.server.Port != 0 { 133 dest.Port = h.server.Port 134 } 135 136 newError("handling DNS traffic to ", dest).WriteToLog(session.ExportIDToError(ctx)) 137 138 conn := &outboundConn{ 139 dialer: func() (internet.Connection, error) { 140 return d.Dial(ctx, dest) 141 }, 142 connReady: make(chan struct{}, 1), 143 } 144 145 var reader dns_proto.MessageReader 146 var writer dns_proto.MessageWriter 147 if srcNetwork == net.Network_TCP { 148 reader = dns_proto.NewTCPReader(link.Reader) 149 writer = &dns_proto.TCPWriter{ 150 Writer: link.Writer, 151 } 152 } else { 153 reader = &dns_proto.UDPReader{ 154 Reader: link.Reader, 155 } 156 writer = &dns_proto.UDPWriter{ 157 Writer: link.Writer, 158 } 159 } 160 161 var connReader dns_proto.MessageReader 162 var connWriter dns_proto.MessageWriter 163 if dest.Network == net.Network_TCP { 164 connReader = dns_proto.NewTCPReader(buf.NewReader(conn)) 165 connWriter = &dns_proto.TCPWriter{ 166 Writer: buf.NewWriter(conn), 167 } 168 } else { 169 connReader = &dns_proto.UDPReader{ 170 Reader: buf.NewPacketReader(conn), 171 } 172 connWriter = &dns_proto.UDPWriter{ 173 Writer: buf.NewWriter(conn), 174 } 175 } 176 177 ctx, cancel := context.WithCancel(ctx) 178 timer := signal.CancelAfterInactivity(ctx, cancel, h.timeout) 179 180 request := func() error { 181 defer conn.Close() 182 183 for { 184 b, err := reader.ReadMessage() 185 if err == io.EOF { 186 return nil 187 } 188 189 if err != nil { 190 return err 191 } 192 193 timer.Update() 194 195 if !h.isOwnLink(ctx) { 196 isIPQuery, domain, id, qType := parseIPQuery(b.Bytes()) 197 if isIPQuery { 198 if domain, err := strmatcher.ToDomain(domain); err == nil { 199 go h.handleIPQuery(id, qType, domain, writer) 200 continue 201 } 202 } 203 } 204 205 if err := connWriter.WriteMessage(b); err != nil { 206 return err 207 } 208 } 209 } 210 211 response := func() error { 212 for { 213 b, err := connReader.ReadMessage() 214 if err == io.EOF { 215 return nil 216 } 217 218 if err != nil { 219 return err 220 } 221 222 timer.Update() 223 224 if err := writer.WriteMessage(b); err != nil { 225 return err 226 } 227 } 228 } 229 230 if err := task.Run(ctx, request, response); err != nil { 231 return newError("connection ends").Base(err) 232 } 233 234 return nil 235 } 236 237 func (h *Handler) handleIPQuery(id uint16, qType dnsmessage.Type, domain string, writer dns_proto.MessageWriter) { 238 var ips []net.IP 239 var err error 240 241 var ttl uint32 = 600 242 243 switch qType { 244 case dnsmessage.TypeA: 245 ips, err = h.ipv4Lookup.LookupIPv4(domain) 246 case dnsmessage.TypeAAAA: 247 ips, err = h.ipv6Lookup.LookupIPv6(domain) 248 } 249 250 rcode := dns.RCodeFromError(err) 251 if rcode == 0 && len(ips) == 0 && err != dns.ErrEmptyResponse { 252 newError("ip query").Base(err).WriteToLog() 253 return 254 } 255 256 b := buf.New() 257 rawBytes := b.Extend(buf.Size) 258 builder := dnsmessage.NewBuilder(rawBytes[:0], dnsmessage.Header{ 259 ID: id, 260 RCode: dnsmessage.RCode(rcode), 261 RecursionAvailable: true, 262 RecursionDesired: true, 263 Response: true, 264 }) 265 builder.EnableCompression() 266 common.Must(builder.StartQuestions()) 267 common.Must(builder.Question(dnsmessage.Question{ 268 Name: dnsmessage.MustNewName(domain), 269 Class: dnsmessage.ClassINET, 270 Type: qType, 271 })) 272 common.Must(builder.StartAnswers()) 273 274 rHeader := dnsmessage.ResourceHeader{Name: dnsmessage.MustNewName(domain), Class: dnsmessage.ClassINET, TTL: ttl} 275 for _, ip := range ips { 276 if len(ip) == net.IPv4len { 277 var r dnsmessage.AResource 278 copy(r.A[:], ip) 279 common.Must(builder.AResource(rHeader, r)) 280 } else { 281 var r dnsmessage.AAAAResource 282 copy(r.AAAA[:], ip) 283 common.Must(builder.AAAAResource(rHeader, r)) 284 } 285 } 286 msgBytes, err := builder.Finish() 287 if err != nil { 288 newError("pack message").Base(err).WriteToLog() 289 b.Release() 290 return 291 } 292 b.Resize(0, int32(len(msgBytes))) 293 294 if err := writer.WriteMessage(b); err != nil { 295 newError("write IP answer").Base(err).WriteToLog() 296 } 297 } 298 299 type outboundConn struct { 300 access sync.Mutex 301 dialer func() (internet.Connection, error) 302 303 conn net.Conn 304 connReady chan struct{} 305 } 306 307 func (c *outboundConn) dial() error { 308 conn, err := c.dialer() 309 if err != nil { 310 return err 311 } 312 c.conn = conn 313 c.connReady <- struct{}{} 314 return nil 315 } 316 317 func (c *outboundConn) Write(b []byte) (int, error) { 318 c.access.Lock() 319 320 if c.conn == nil { 321 if err := c.dial(); err != nil { 322 c.access.Unlock() 323 newError("failed to dial outbound connection").Base(err).AtWarning().WriteToLog() 324 return len(b), nil 325 } 326 } 327 328 c.access.Unlock() 329 330 return c.conn.Write(b) 331 } 332 333 func (c *outboundConn) Read(b []byte) (int, error) { 334 var conn net.Conn 335 c.access.Lock() 336 conn = c.conn 337 c.access.Unlock() 338 339 if conn == nil { 340 _, open := <-c.connReady 341 if !open { 342 return 0, io.EOF 343 } 344 conn = c.conn 345 } 346 347 return conn.Read(b) 348 } 349 350 func (c *outboundConn) Close() error { 351 c.access.Lock() 352 close(c.connReady) 353 if c.conn != nil { 354 c.conn.Close() 355 } 356 c.access.Unlock() 357 return nil 358 }