github.com/xtls/xray-core@v1.8.12-0.20240518155711-3168d27b0bdb/proxy/dns/dns_test.go (about)

     1  package dns_test
     2  
     3  import (
     4  	"strconv"
     5  	"testing"
     6  	"time"
     7  
     8  	"github.com/google/go-cmp/cmp"
     9  	"github.com/miekg/dns"
    10  	"github.com/xtls/xray-core/app/dispatcher"
    11  	dnsapp "github.com/xtls/xray-core/app/dns"
    12  	"github.com/xtls/xray-core/app/policy"
    13  	"github.com/xtls/xray-core/app/proxyman"
    14  	_ "github.com/xtls/xray-core/app/proxyman/inbound"
    15  	_ "github.com/xtls/xray-core/app/proxyman/outbound"
    16  	"github.com/xtls/xray-core/common"
    17  	"github.com/xtls/xray-core/common/net"
    18  	"github.com/xtls/xray-core/common/serial"
    19  	"github.com/xtls/xray-core/core"
    20  	dns_proxy "github.com/xtls/xray-core/proxy/dns"
    21  	"github.com/xtls/xray-core/proxy/dokodemo"
    22  	"github.com/xtls/xray-core/testing/servers/tcp"
    23  	"github.com/xtls/xray-core/testing/servers/udp"
    24  )
    25  
    26  type staticHandler struct{}
    27  
    28  func (*staticHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
    29  	ans := new(dns.Msg)
    30  	ans.Id = r.Id
    31  
    32  	var clientIP net.IP
    33  
    34  	opt := r.IsEdns0()
    35  	if opt != nil {
    36  		for _, o := range opt.Option {
    37  			if o.Option() == dns.EDNS0SUBNET {
    38  				subnet := o.(*dns.EDNS0_SUBNET)
    39  				clientIP = subnet.Address
    40  			}
    41  		}
    42  	}
    43  
    44  	for _, q := range r.Question {
    45  		switch {
    46  		case q.Name == "google.com." && q.Qtype == dns.TypeA:
    47  			if clientIP == nil {
    48  				rr, _ := dns.NewRR("google.com. IN A 8.8.8.8")
    49  				ans.Answer = append(ans.Answer, rr)
    50  			} else {
    51  				rr, _ := dns.NewRR("google.com. IN A 8.8.4.4")
    52  				ans.Answer = append(ans.Answer, rr)
    53  			}
    54  
    55  		case q.Name == "facebook.com." && q.Qtype == dns.TypeA:
    56  			rr, _ := dns.NewRR("facebook.com. IN A 9.9.9.9")
    57  			ans.Answer = append(ans.Answer, rr)
    58  
    59  		case q.Name == "ipv6.google.com." && q.Qtype == dns.TypeA:
    60  			rr, err := dns.NewRR("ipv6.google.com. IN A 8.8.8.7")
    61  			common.Must(err)
    62  			ans.Answer = append(ans.Answer, rr)
    63  
    64  		case q.Name == "ipv6.google.com." && q.Qtype == dns.TypeAAAA:
    65  			rr, err := dns.NewRR("ipv6.google.com. IN AAAA 2001:4860:4860::8888")
    66  			common.Must(err)
    67  			ans.Answer = append(ans.Answer, rr)
    68  
    69  		case q.Name == "notexist.google.com." && q.Qtype == dns.TypeAAAA:
    70  			ans.MsgHdr.Rcode = dns.RcodeNameError
    71  		}
    72  	}
    73  	w.WriteMsg(ans)
    74  }
    75  
    76  func TestUDPDNSTunnel(t *testing.T) {
    77  	port := udp.PickPort()
    78  
    79  	dnsServer := dns.Server{
    80  		Addr:    "127.0.0.1:" + port.String(),
    81  		Net:     "udp",
    82  		Handler: &staticHandler{},
    83  		UDPSize: 1200,
    84  	}
    85  	defer dnsServer.Shutdown()
    86  
    87  	go dnsServer.ListenAndServe()
    88  	time.Sleep(time.Second)
    89  
    90  	serverPort := udp.PickPort()
    91  	config := &core.Config{
    92  		App: []*serial.TypedMessage{
    93  			serial.ToTypedMessage(&dnsapp.Config{
    94  				NameServers: []*net.Endpoint{
    95  					{
    96  						Network: net.Network_UDP,
    97  						Address: &net.IPOrDomain{
    98  							Address: &net.IPOrDomain_Ip{
    99  								Ip: []byte{127, 0, 0, 1},
   100  							},
   101  						},
   102  						Port: uint32(port),
   103  					},
   104  				},
   105  			}),
   106  			serial.ToTypedMessage(&dispatcher.Config{}),
   107  			serial.ToTypedMessage(&proxyman.OutboundConfig{}),
   108  			serial.ToTypedMessage(&proxyman.InboundConfig{}),
   109  			serial.ToTypedMessage(&policy.Config{}),
   110  		},
   111  		Inbound: []*core.InboundHandlerConfig{
   112  			{
   113  				ProxySettings: serial.ToTypedMessage(&dokodemo.Config{
   114  					Address:  net.NewIPOrDomain(net.LocalHostIP),
   115  					Port:     uint32(port),
   116  					Networks: []net.Network{net.Network_UDP},
   117  				}),
   118  				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
   119  					PortList: &net.PortList{Range: []*net.PortRange{net.SinglePortRange(serverPort)}},
   120  					Listen:   net.NewIPOrDomain(net.LocalHostIP),
   121  				}),
   122  			},
   123  		},
   124  		Outbound: []*core.OutboundHandlerConfig{
   125  			{
   126  				ProxySettings: serial.ToTypedMessage(&dns_proxy.Config{}),
   127  			},
   128  		},
   129  	}
   130  
   131  	v, err := core.New(config)
   132  	common.Must(err)
   133  	common.Must(v.Start())
   134  	defer v.Close()
   135  
   136  	{
   137  		m1 := new(dns.Msg)
   138  		m1.Id = dns.Id()
   139  		m1.RecursionDesired = true
   140  		m1.Question = make([]dns.Question, 1)
   141  		m1.Question[0] = dns.Question{Name: "google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
   142  
   143  		c := new(dns.Client)
   144  		in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
   145  		common.Must(err)
   146  
   147  		if len(in.Answer) != 1 {
   148  			t.Fatal("len(answer): ", len(in.Answer))
   149  		}
   150  
   151  		rr, ok := in.Answer[0].(*dns.A)
   152  		if !ok {
   153  			t.Fatal("not A record")
   154  		}
   155  		if r := cmp.Diff(rr.A[:], net.IP{8, 8, 8, 8}); r != "" {
   156  			t.Error(r)
   157  		}
   158  	}
   159  
   160  	{
   161  		m1 := new(dns.Msg)
   162  		m1.Id = dns.Id()
   163  		m1.RecursionDesired = true
   164  		m1.Question = make([]dns.Question, 1)
   165  		m1.Question[0] = dns.Question{Name: "ipv4only.google.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
   166  
   167  		c := new(dns.Client)
   168  		c.Timeout = 10 * time.Second
   169  		in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
   170  		common.Must(err)
   171  
   172  		if len(in.Answer) != 0 {
   173  			t.Fatal("len(answer): ", len(in.Answer))
   174  		}
   175  	}
   176  
   177  	{
   178  		m1 := new(dns.Msg)
   179  		m1.Id = dns.Id()
   180  		m1.RecursionDesired = true
   181  		m1.Question = make([]dns.Question, 1)
   182  		m1.Question[0] = dns.Question{Name: "notexist.google.com.", Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
   183  
   184  		c := new(dns.Client)
   185  		in, _, err := c.Exchange(m1, "127.0.0.1:"+strconv.Itoa(int(serverPort)))
   186  		common.Must(err)
   187  
   188  		if in.Rcode != dns.RcodeNameError {
   189  			t.Error("expected NameError, but got ", in.Rcode)
   190  		}
   191  	}
   192  }
   193  
   194  func TestTCPDNSTunnel(t *testing.T) {
   195  	port := udp.PickPort()
   196  
   197  	dnsServer := dns.Server{
   198  		Addr:    "127.0.0.1:" + port.String(),
   199  		Net:     "udp",
   200  		Handler: &staticHandler{},
   201  	}
   202  	defer dnsServer.Shutdown()
   203  
   204  	go dnsServer.ListenAndServe()
   205  	time.Sleep(time.Second)
   206  
   207  	serverPort := tcp.PickPort()
   208  	config := &core.Config{
   209  		App: []*serial.TypedMessage{
   210  			serial.ToTypedMessage(&dnsapp.Config{
   211  				NameServer: []*dnsapp.NameServer{
   212  					{
   213  						Address: &net.Endpoint{
   214  							Network: net.Network_UDP,
   215  							Address: &net.IPOrDomain{
   216  								Address: &net.IPOrDomain_Ip{
   217  									Ip: []byte{127, 0, 0, 1},
   218  								},
   219  							},
   220  							Port: uint32(port),
   221  						},
   222  					},
   223  				},
   224  			}),
   225  			serial.ToTypedMessage(&dispatcher.Config{}),
   226  			serial.ToTypedMessage(&proxyman.OutboundConfig{}),
   227  			serial.ToTypedMessage(&proxyman.InboundConfig{}),
   228  			serial.ToTypedMessage(&policy.Config{}),
   229  		},
   230  		Inbound: []*core.InboundHandlerConfig{
   231  			{
   232  				ProxySettings: serial.ToTypedMessage(&dokodemo.Config{
   233  					Address:  net.NewIPOrDomain(net.LocalHostIP),
   234  					Port:     uint32(port),
   235  					Networks: []net.Network{net.Network_TCP},
   236  				}),
   237  				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
   238  					PortList: &net.PortList{Range: []*net.PortRange{net.SinglePortRange(serverPort)}},
   239  					Listen:   net.NewIPOrDomain(net.LocalHostIP),
   240  				}),
   241  			},
   242  		},
   243  		Outbound: []*core.OutboundHandlerConfig{
   244  			{
   245  				ProxySettings: serial.ToTypedMessage(&dns_proxy.Config{}),
   246  			},
   247  		},
   248  	}
   249  
   250  	v, err := core.New(config)
   251  	common.Must(err)
   252  	common.Must(v.Start())
   253  	defer v.Close()
   254  
   255  	m1 := new(dns.Msg)
   256  	m1.Id = dns.Id()
   257  	m1.RecursionDesired = true
   258  	m1.Question = make([]dns.Question, 1)
   259  	m1.Question[0] = dns.Question{Name: "google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
   260  
   261  	c := &dns.Client{
   262  		Net: "tcp",
   263  	}
   264  	in, _, err := c.Exchange(m1, "127.0.0.1:"+serverPort.String())
   265  	common.Must(err)
   266  
   267  	if len(in.Answer) != 1 {
   268  		t.Fatal("len(answer): ", len(in.Answer))
   269  	}
   270  
   271  	rr, ok := in.Answer[0].(*dns.A)
   272  	if !ok {
   273  		t.Fatal("not A record")
   274  	}
   275  	if r := cmp.Diff(rr.A[:], net.IP{8, 8, 8, 8}); r != "" {
   276  		t.Error(r)
   277  	}
   278  }
   279  
   280  func TestUDP2TCPDNSTunnel(t *testing.T) {
   281  	port := tcp.PickPort()
   282  
   283  	dnsServer := dns.Server{
   284  		Addr:    "127.0.0.1:" + port.String(),
   285  		Net:     "tcp",
   286  		Handler: &staticHandler{},
   287  	}
   288  	defer dnsServer.Shutdown()
   289  
   290  	go dnsServer.ListenAndServe()
   291  	time.Sleep(time.Second)
   292  
   293  	serverPort := tcp.PickPort()
   294  	config := &core.Config{
   295  		App: []*serial.TypedMessage{
   296  			serial.ToTypedMessage(&dnsapp.Config{
   297  				NameServer: []*dnsapp.NameServer{
   298  					{
   299  						Address: &net.Endpoint{
   300  							Network: net.Network_UDP,
   301  							Address: &net.IPOrDomain{
   302  								Address: &net.IPOrDomain_Ip{
   303  									Ip: []byte{127, 0, 0, 1},
   304  								},
   305  							},
   306  							Port: uint32(port),
   307  						},
   308  					},
   309  				},
   310  			}),
   311  			serial.ToTypedMessage(&dispatcher.Config{}),
   312  			serial.ToTypedMessage(&proxyman.OutboundConfig{}),
   313  			serial.ToTypedMessage(&proxyman.InboundConfig{}),
   314  			serial.ToTypedMessage(&policy.Config{}),
   315  		},
   316  		Inbound: []*core.InboundHandlerConfig{
   317  			{
   318  				ProxySettings: serial.ToTypedMessage(&dokodemo.Config{
   319  					Address:  net.NewIPOrDomain(net.LocalHostIP),
   320  					Port:     uint32(port),
   321  					Networks: []net.Network{net.Network_TCP},
   322  				}),
   323  				ReceiverSettings: serial.ToTypedMessage(&proxyman.ReceiverConfig{
   324  					PortList: &net.PortList{Range: []*net.PortRange{net.SinglePortRange(serverPort)}},
   325  					Listen:   net.NewIPOrDomain(net.LocalHostIP),
   326  				}),
   327  			},
   328  		},
   329  		Outbound: []*core.OutboundHandlerConfig{
   330  			{
   331  				ProxySettings: serial.ToTypedMessage(&dns_proxy.Config{
   332  					Server: &net.Endpoint{
   333  						Network: net.Network_TCP,
   334  					},
   335  				}),
   336  			},
   337  		},
   338  	}
   339  
   340  	v, err := core.New(config)
   341  	common.Must(err)
   342  	common.Must(v.Start())
   343  	defer v.Close()
   344  
   345  	m1 := new(dns.Msg)
   346  	m1.Id = dns.Id()
   347  	m1.RecursionDesired = true
   348  	m1.Question = make([]dns.Question, 1)
   349  	m1.Question[0] = dns.Question{Name: "google.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
   350  
   351  	c := &dns.Client{
   352  		Net: "tcp",
   353  	}
   354  	in, _, err := c.Exchange(m1, "127.0.0.1:"+serverPort.String())
   355  	common.Must(err)
   356  
   357  	if len(in.Answer) != 1 {
   358  		t.Fatal("len(answer): ", len(in.Answer))
   359  	}
   360  
   361  	rr, ok := in.Answer[0].(*dns.A)
   362  	if !ok {
   363  		t.Fatal("not A record")
   364  	}
   365  	if r := cmp.Diff(rr.A[:], net.IP{8, 8, 8, 8}); r != "" {
   366  		t.Error(r)
   367  	}
   368  }