github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/tun/tun_test.go (about)

     1  /*
     2   * Copyright (c) 2017, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package tun
    21  
    22  import (
    23  	"bytes"
    24  	"errors"
    25  	"fmt"
    26  	"io"
    27  	"math/rand"
    28  	"net"
    29  	"os"
    30  	"strconv"
    31  	"sync"
    32  	"sync/atomic"
    33  	"syscall"
    34  	"testing"
    35  	"time"
    36  
    37  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    38  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
    39  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/stacktrace"
    40  	"github.com/miekg/dns"
    41  )
    42  
    43  const (
    44  	UNIX_DOMAIN_SOCKET_NAME = "/tmp/tun_test.sock"
    45  	SESSION_ID_LENGTH       = 16
    46  	TCP_PORT                = 8000
    47  	TCP_RELAY_CHUNK_SIZE    = int64(65536)
    48  	TCP_RELAY_TOTAL_SIZE    = int64(1073741824)
    49  	CONCURRENT_CLIENT_COUNT = 5
    50  	PACKET_METRICS_TIMEOUT  = 10 * time.Second
    51  )
    52  
    53  func TestTunneledTCPIPv4(t *testing.T) {
    54  	testTunneledTCP(t, false)
    55  }
    56  
    57  func TestTunneledTCPIPv6(t *testing.T) {
    58  	testTunneledTCP(t, true)
    59  }
    60  
    61  func TestSessionExpiry(t *testing.T) {
    62  	t.Skip("TODO: test short session TTLs actually persist/expire as expected")
    63  }
    64  
    65  func TestTrafficRules(t *testing.T) {
    66  	t.Skip("TODO: negative tests for checkAllowedTCPPortFunc, checkAllowedUDPPortFunc")
    67  }
    68  
    69  func TestResetRouting(t *testing.T) {
    70  	t.Skip("TODO: test conntrack delete effectiveness")
    71  }
    72  
    73  func testTunneledTCP(t *testing.T, useIPv6 bool) {
    74  
    75  	// This test harness does the following:
    76  	//
    77  	// - starts a TCP server; this server echoes the data it receives
    78  	// - starts a packet tunnel server that uses a unix domain socket for client channels
    79  	// - starts CONCURRENT_CLIENT_COUNT concurrent clients
    80  	// - each client runs a packet tunnel client connected to the server unix domain socket
    81  	// - one client first performs a tunneled DNS query against an external DNS server
    82  	// - clients establish a TCP client connection to the TCP server through the packet tunnel
    83  	// - each TCP client transfers TCP_RELAY_TOTAL_SIZE bytes to the TCP server
    84  	// - the test checks that all data echoes back correctly and that the server packet
    85  	//   metrics reflects the expected amount of data transferred through the tunnel
    86  	// - the test also checks that the flow activity updater mechanism correctly reports
    87  	//   the total bytes transferred
    88  	// - this test runs in either IPv4 or IPv6 mode
    89  	// - the test host's public IP address is used as the TCP server IP address; it is
    90  	//   expected that the server tun device will NAT to the public interface; clients
    91  	//   use SO_BINDTODEVICE/IP_BOUND_IF to force the TCP client connections through the
    92  	//   tunnel
    93  	//
    94  	// Note: this test can modify host network configuration; in addition to tun device
    95  	// and routing config, see the changes made in fixBindToDevice.
    96  
    97  	if TCP_RELAY_TOTAL_SIZE%TCP_RELAY_CHUNK_SIZE != 0 {
    98  		t.Fatalf("invalid relay size")
    99  	}
   100  
   101  	MTU := DEFAULT_MTU
   102  
   103  	testTCPServer, err := startTestTCPServer(useIPv6)
   104  	if err != nil {
   105  		if err == errNoIPAddress {
   106  			t.Skipf("test unsupported: %s", errNoIPAddress)
   107  		}
   108  		t.Fatalf("startTestTCPServer failed: %s", err)
   109  	}
   110  
   111  	var flowCounter bytesTransferredCounter
   112  
   113  	flowActivityUpdaterMaker := func(_ string, IPAddress net.IP) []FlowActivityUpdater {
   114  
   115  		if IPAddress.String() != testTCPServer.getListenerIPAddress() {
   116  			t.Fatalf("unexpected flow IP address")
   117  		}
   118  
   119  		return []FlowActivityUpdater{&flowCounter}
   120  	}
   121  
   122  	var metricsCounter bytesTransferredCounter
   123  
   124  	metricsUpdater := func(TCPApplicationBytesDown, TCPApplicationBytesUp, _, _ int64) {
   125  		metricsCounter.UpdateProgress(
   126  			TCPApplicationBytesDown, TCPApplicationBytesUp, 0)
   127  	}
   128  
   129  	testServer, err := startTestServer(useIPv6, MTU, flowActivityUpdaterMaker, metricsUpdater)
   130  	if err != nil {
   131  		t.Fatalf("startTestServer failed: %s", err)
   132  	}
   133  
   134  	results := make(chan error, CONCURRENT_CLIENT_COUNT)
   135  
   136  	for i := 0; i < CONCURRENT_CLIENT_COUNT; i++ {
   137  		go func(clientNum int) {
   138  
   139  			testClient, err := startTestClient(
   140  				useIPv6, MTU, []string{testTCPServer.getListenerIPAddress()})
   141  			if err != nil {
   142  				results <- fmt.Errorf("startTestClient failed: %s", err)
   143  				return
   144  			}
   145  
   146  			// Test one tunneled DNS query.
   147  
   148  			if clientNum == 0 {
   149  				err = testDNSClient(
   150  					useIPv6,
   151  					testClient.tunClient.device.Name())
   152  				if err != nil {
   153  					results <- fmt.Errorf("testDNSClient failed: %s", err)
   154  					return
   155  				}
   156  			}
   157  
   158  			// The TCP client will bind to the packet tunnel client tun
   159  			// device and connect to the TCP server. With the bind to
   160  			// device, TCP packets will flow through the packet tunnel
   161  			// client to the packet tunnel server, through the packet tunnel
   162  			// server's tun device, NATed to the server's public interface,
   163  			// and finally reaching the TCP server. All this happens on
   164  			// the single host running the test.
   165  
   166  			testTCPClient, err := startTestTCPClient(
   167  				testClient.tunClient.device.Name(),
   168  				testTCPServer.getListenerIPAddress())
   169  			if err != nil {
   170  				results <- fmt.Errorf("startTestTCPClient failed: %s", err)
   171  				return
   172  			}
   173  
   174  			// Send TCP_RELAY_TOTAL_SIZE random bytes to the TCP server, and
   175  			// check that it echoes back the same bytes.
   176  
   177  			sendChunk, receiveChunk := make([]byte, TCP_RELAY_CHUNK_SIZE), make([]byte, TCP_RELAY_CHUNK_SIZE)
   178  
   179  			for i := int64(0); i < TCP_RELAY_TOTAL_SIZE; i += TCP_RELAY_CHUNK_SIZE {
   180  
   181  				_, err := rand.Read(sendChunk)
   182  				if err != nil {
   183  					results <- fmt.Errorf("rand.Read failed: %s", err)
   184  					return
   185  				}
   186  
   187  				_, err = testTCPClient.Write(sendChunk)
   188  				if err != nil {
   189  					results <- fmt.Errorf("mockTCPClient.Write failed: %s", err)
   190  					return
   191  				}
   192  
   193  				_, err = io.ReadFull(testTCPClient, receiveChunk)
   194  				if err != nil {
   195  					results <- fmt.Errorf("io.ReadFull failed: %s", err)
   196  					return
   197  				}
   198  
   199  				if !bytes.Equal(sendChunk, receiveChunk) {
   200  					results <- fmt.Errorf("bytes.Equal failed")
   201  					return
   202  				}
   203  			}
   204  
   205  			testTCPClient.stop()
   206  
   207  			// Allow some time for the TCP FIN to be tunneled, for a clean shutdown.
   208  			time.Sleep(100 * time.Millisecond)
   209  
   210  			testClient.stop()
   211  
   212  			// Check metrics to ensure traffic was tunneled and metrics reported
   213  
   214  			// Note: this code does not ensure that the "last" packet metrics was
   215  			// for this very client; but all packet metrics should be the same.
   216  
   217  			packetMetricsFields := testServer.logger.getLastPacketMetrics()
   218  
   219  			if packetMetricsFields == nil {
   220  				results <- fmt.Errorf("testServer.logger.getLastPacketMetrics failed")
   221  				return
   222  			}
   223  
   224  			expectedFields := []struct {
   225  				nameSuffix   string
   226  				minimumValue int64
   227  			}{
   228  				{"packets_up", TCP_RELAY_TOTAL_SIZE / int64(MTU)},
   229  				{"packets_down", TCP_RELAY_TOTAL_SIZE / int64(MTU)},
   230  				{"bytes_up", TCP_RELAY_TOTAL_SIZE},
   231  				{"bytes_down", TCP_RELAY_TOTAL_SIZE},
   232  				{"application_bytes_up", TCP_RELAY_TOTAL_SIZE},
   233  				{"application_bytes_down", TCP_RELAY_TOTAL_SIZE},
   234  			}
   235  
   236  			for _, expectedField := range expectedFields {
   237  				var name string
   238  				if useIPv6 {
   239  					name = "tcp_ipv6_" + expectedField.nameSuffix
   240  				} else {
   241  					name = "tcp_ipv4_" + expectedField.nameSuffix
   242  				}
   243  				field, ok := packetMetricsFields[name]
   244  				if !ok {
   245  					results <- fmt.Errorf("missing expected metric field: %s", name)
   246  					return
   247  				}
   248  				value, ok := field.(int64)
   249  				if !ok {
   250  					results <- fmt.Errorf("unexpected metric field type: %s", name)
   251  					return
   252  				}
   253  				if value < expectedField.minimumValue {
   254  					results <- fmt.Errorf("unexpected metric field value: %s: %d", name, value)
   255  					return
   256  				}
   257  			}
   258  
   259  			results <- nil
   260  		}(i)
   261  	}
   262  
   263  	for i := 0; i < CONCURRENT_CLIENT_COUNT; i++ {
   264  		result := <-results
   265  		if result != nil {
   266  			t.Fatalf(result.Error())
   267  		}
   268  	}
   269  
   270  	// Note: reported bytes transferred can exceed expected bytes
   271  	// transferred due to retransmission of packets.
   272  
   273  	expectedBytesTransferred := CONCURRENT_CLIENT_COUNT * TCP_RELAY_TOTAL_SIZE
   274  
   275  	downstreamBytesTransferred, upstreamBytesTransferred, _ := flowCounter.Get()
   276  	if downstreamBytesTransferred < expectedBytesTransferred {
   277  		t.Fatalf("unexpected flow downstreamBytesTransferred: %d; expected at least %d",
   278  			downstreamBytesTransferred, expectedBytesTransferred)
   279  	}
   280  	if upstreamBytesTransferred < expectedBytesTransferred {
   281  		t.Fatalf("unexpected flow upstreamBytesTransferred: %d; expected at least %d",
   282  			upstreamBytesTransferred, expectedBytesTransferred)
   283  	}
   284  
   285  	downstreamBytesTransferred, upstreamBytesTransferred, _ = metricsCounter.Get()
   286  	if downstreamBytesTransferred < expectedBytesTransferred {
   287  		t.Fatalf("unexpected metrics downstreamBytesTransferred: %d; expected at least %d",
   288  			downstreamBytesTransferred, expectedBytesTransferred)
   289  	}
   290  	if upstreamBytesTransferred < expectedBytesTransferred {
   291  		t.Fatalf("unexpected metrics upstreamBytesTransferred: %d; expected at least %d",
   292  			upstreamBytesTransferred, expectedBytesTransferred)
   293  	}
   294  
   295  	testServer.stop()
   296  
   297  	testTCPServer.stop()
   298  }
   299  
   300  type bytesTransferredCounter struct {
   301  	// Note: 64-bit ints used with atomic operations are placed
   302  	// at the start of struct to ensure 64-bit alignment.
   303  	// (https://golang.org/pkg/sync/atomic/#pkg-note-BUG)
   304  	downstreamBytes     int64
   305  	upstreamBytes       int64
   306  	durationNanoseconds int64
   307  }
   308  
   309  func (counter *bytesTransferredCounter) UpdateProgress(
   310  	downstreamBytes, upstreamBytes int64, durationNanoseconds int64) {
   311  
   312  	atomic.AddInt64(&counter.downstreamBytes, downstreamBytes)
   313  	atomic.AddInt64(&counter.upstreamBytes, upstreamBytes)
   314  	atomic.AddInt64(&counter.durationNanoseconds, durationNanoseconds)
   315  }
   316  
   317  func (counter *bytesTransferredCounter) Get() (int64, int64, int64) {
   318  	return atomic.LoadInt64(&counter.downstreamBytes),
   319  		atomic.LoadInt64(&counter.upstreamBytes),
   320  		atomic.LoadInt64(&counter.durationNanoseconds)
   321  }
   322  
   323  type testServer struct {
   324  	logger         *testLogger
   325  	updaterMaker   FlowActivityUpdaterMaker
   326  	metricsUpdater MetricsUpdater
   327  	tunServer      *Server
   328  	unixListener   net.Listener
   329  	clientConns    *common.Conns
   330  	workers        *sync.WaitGroup
   331  }
   332  
   333  func startTestServer(
   334  	useIPv6 bool,
   335  	MTU int,
   336  	updaterMaker FlowActivityUpdaterMaker,
   337  	metricsUpdater MetricsUpdater) (*testServer, error) {
   338  
   339  	logger := newTestLogger(true)
   340  
   341  	getDNSResolverIPv4Addresses := func() []net.IP {
   342  		return []net.IP{net.ParseIP("8.8.8.8")}
   343  	}
   344  
   345  	getDNSResolverIPv6Addresses := func() []net.IP {
   346  		return []net.IP{net.ParseIP("2001:4860:4860::8888")}
   347  	}
   348  
   349  	config := &ServerConfig{
   350  		Logger:                          logger,
   351  		SudoNetworkConfigCommands:       os.Getenv("TUN_TEST_SUDO") != "",
   352  		AllowNoIPv6NetworkConfiguration: !useIPv6,
   353  		GetDNSResolverIPv4Addresses:     getDNSResolverIPv4Addresses,
   354  		GetDNSResolverIPv6Addresses:     getDNSResolverIPv6Addresses,
   355  		MTU:                             MTU,
   356  		AllowBogons:                     true,
   357  	}
   358  
   359  	tunServer, err := NewServer(config)
   360  	if err != nil {
   361  		return nil, fmt.Errorf("startTestServer(): NewServer failed: %s", err)
   362  	}
   363  
   364  	tunServer.Start()
   365  
   366  	_ = syscall.Unlink(UNIX_DOMAIN_SOCKET_NAME)
   367  
   368  	unixListener, err := net.Listen("unix", UNIX_DOMAIN_SOCKET_NAME)
   369  	if err != nil {
   370  		return nil, fmt.Errorf("startTestServer(): net.Listen failed: %s", err)
   371  	}
   372  
   373  	server := &testServer{
   374  		logger:         logger,
   375  		updaterMaker:   updaterMaker,
   376  		metricsUpdater: metricsUpdater,
   377  		tunServer:      tunServer,
   378  		unixListener:   unixListener,
   379  		clientConns:    common.NewConns(),
   380  		workers:        new(sync.WaitGroup),
   381  	}
   382  
   383  	server.workers.Add(1)
   384  	go server.run()
   385  
   386  	return server, nil
   387  }
   388  
   389  func (server *testServer) run() {
   390  	defer server.workers.Done()
   391  
   392  	for {
   393  		clientConn, err := server.unixListener.Accept()
   394  		if err != nil {
   395  			fmt.Printf("testServer.run(): unixListener.Accept failed: %s\n", err)
   396  			return
   397  		}
   398  
   399  		signalConn := newSignalConn(clientConn)
   400  
   401  		if !server.clientConns.Add(signalConn) {
   402  			return
   403  		}
   404  
   405  		server.workers.Add(1)
   406  		go func() {
   407  			defer server.workers.Done()
   408  			defer signalConn.Close()
   409  
   410  			sessionID := prng.HexString(SESSION_ID_LENGTH)
   411  
   412  			checkAllowedPortFunc := func(net.IP, int) bool { return true }
   413  			checkAllowedDomainFunc := func(string) bool { return true }
   414  
   415  			dnsQualityReporter := func(_ bool, _ time.Duration, _ net.IP) {}
   416  
   417  			server.tunServer.ClientConnected(
   418  				sessionID,
   419  				signalConn,
   420  				checkAllowedPortFunc,
   421  				checkAllowedPortFunc,
   422  				checkAllowedDomainFunc,
   423  				server.updaterMaker,
   424  				server.metricsUpdater,
   425  				dnsQualityReporter)
   426  
   427  			signalConn.Wait()
   428  
   429  			server.tunServer.ClientDisconnected(
   430  				sessionID)
   431  		}()
   432  	}
   433  }
   434  
   435  func (server *testServer) stop() {
   436  	server.clientConns.CloseAll()
   437  	server.unixListener.Close()
   438  	server.workers.Wait()
   439  	server.tunServer.Stop()
   440  }
   441  
   442  type signalConn struct {
   443  	net.Conn
   444  	ioErrorSignal chan struct{}
   445  }
   446  
   447  func newSignalConn(baseConn net.Conn) *signalConn {
   448  	return &signalConn{
   449  		Conn:          baseConn,
   450  		ioErrorSignal: make(chan struct{}, 1),
   451  	}
   452  }
   453  
   454  func (conn *signalConn) Read(p []byte) (n int, err error) {
   455  	n, err = conn.Conn.Read(p)
   456  	if err != nil {
   457  		_ = conn.Conn.Close()
   458  		select {
   459  		case conn.ioErrorSignal <- struct{}{}:
   460  		default:
   461  		}
   462  	}
   463  	return
   464  }
   465  
   466  func (conn *signalConn) Write(p []byte) (n int, err error) {
   467  	n, err = conn.Conn.Write(p)
   468  	if err != nil {
   469  		_ = conn.Conn.Close()
   470  		select {
   471  		case conn.ioErrorSignal <- struct{}{}:
   472  		default:
   473  		}
   474  	}
   475  	return
   476  }
   477  
   478  func (conn *signalConn) Wait() {
   479  	<-conn.ioErrorSignal
   480  }
   481  
   482  type testClient struct {
   483  	logger    *testLogger
   484  	unixConn  net.Conn
   485  	tunClient *Client
   486  }
   487  
   488  func startTestClient(
   489  	useIPv6 bool,
   490  	MTU int,
   491  	routeDestinations []string) (*testClient, error) {
   492  
   493  	unixConn, err := net.Dial("unix", UNIX_DOMAIN_SOCKET_NAME)
   494  	if err != nil {
   495  		return nil, fmt.Errorf("startTestClient(): net.Dial failed: %s", err)
   496  	}
   497  
   498  	logger := newTestLogger(false)
   499  
   500  	// Assumes IP addresses are available on test host
   501  
   502  	// TODO: assign unique IP to each testClient?
   503  
   504  	config := &ClientConfig{
   505  		Logger:                          logger,
   506  		SudoNetworkConfigCommands:       os.Getenv("TUN_TEST_SUDO") != "",
   507  		AllowNoIPv6NetworkConfiguration: !useIPv6,
   508  		IPv4AddressCIDR:                 "172.16.0.1/24",
   509  		IPv6AddressCIDR:                 "fd26:b6a6:4454:310a:0000:0000:0000:0001/64",
   510  		RouteDestinations:               routeDestinations,
   511  		Transport:                       unixConn,
   512  		MTU:                             MTU,
   513  	}
   514  
   515  	tunClient, err := NewClient(config)
   516  	if err != nil {
   517  		return nil, fmt.Errorf("startTestClient(): NewClient failed: %s", err)
   518  	}
   519  
   520  	// Configure kernel to fix issue described in fixBindToDevice
   521  
   522  	err = fixBindToDevice(logger, config.SudoNetworkConfigCommands, tunClient.device.Name())
   523  	if err != nil {
   524  		return nil, fmt.Errorf("startTestClient(): fixBindToDevice failed: %s", err)
   525  	}
   526  
   527  	tunClient.Start()
   528  
   529  	return &testClient{
   530  		logger:    logger,
   531  		unixConn:  unixConn,
   532  		tunClient: tunClient,
   533  	}, nil
   534  }
   535  
   536  func (client *testClient) stop() {
   537  	client.tunClient.Stop()
   538  	client.unixConn.Close()
   539  }
   540  
   541  type testTCPServer struct {
   542  	listenerIPAddress string
   543  	tcpListener       net.Listener
   544  	clientConns       *common.Conns
   545  	workers           *sync.WaitGroup
   546  }
   547  
   548  var errNoIPAddress = errors.New("no IP address")
   549  
   550  func startTestTCPServer(useIPv6 bool) (*testTCPServer, error) {
   551  
   552  	interfaceName := DEFAULT_PUBLIC_INTERFACE_NAME
   553  
   554  	hostIPaddress := ""
   555  
   556  	IPv4Address, IPv6Address, err := common.GetInterfaceIPAddresses(interfaceName)
   557  	if err != nil {
   558  		return nil, fmt.Errorf("startTestTCPServer(): GetInterfaceIPAddresses failed: %s", err)
   559  	}
   560  
   561  	if useIPv6 {
   562  		// Cannot route to link local address
   563  		if IPv6Address == nil || IPv6Address.IsLinkLocalUnicast() {
   564  			return nil, errNoIPAddress
   565  		}
   566  		hostIPaddress = IPv6Address.String()
   567  	} else {
   568  		if IPv4Address == nil || IPv4Address.IsLinkLocalUnicast() {
   569  			return nil, errNoIPAddress
   570  		}
   571  		hostIPaddress = IPv4Address.String()
   572  	}
   573  
   574  	tcpListener, err := net.Listen("tcp", net.JoinHostPort(hostIPaddress, strconv.Itoa(TCP_PORT)))
   575  	if err != nil {
   576  		return nil, fmt.Errorf("startTestTCPServer(): net.Listen failed: %s", err)
   577  	}
   578  
   579  	server := &testTCPServer{
   580  		listenerIPAddress: hostIPaddress,
   581  		tcpListener:       tcpListener,
   582  		clientConns:       common.NewConns(),
   583  		workers:           new(sync.WaitGroup),
   584  	}
   585  
   586  	server.workers.Add(1)
   587  	go server.run()
   588  
   589  	return server, nil
   590  }
   591  
   592  func (server *testTCPServer) getListenerIPAddress() string {
   593  	return server.listenerIPAddress
   594  }
   595  
   596  func (server *testTCPServer) run() {
   597  	defer server.workers.Done()
   598  
   599  	for {
   600  		clientConn, err := server.tcpListener.Accept()
   601  		if err != nil {
   602  			fmt.Printf("testTCPServer.run(): tcpListener.Accept failed: %s\n", err)
   603  			return
   604  		}
   605  
   606  		if !server.clientConns.Add(clientConn) {
   607  			return
   608  		}
   609  
   610  		server.workers.Add(1)
   611  		go func() {
   612  			defer server.workers.Done()
   613  			defer clientConn.Close()
   614  
   615  			buffer := make([]byte, TCP_RELAY_CHUNK_SIZE)
   616  
   617  			for {
   618  				_, err := io.ReadFull(clientConn, buffer)
   619  				if err != nil {
   620  					fmt.Printf("testTCPServer.run(): io.ReadFull failed: %s\n", err)
   621  					return
   622  				}
   623  				_, err = clientConn.Write(buffer)
   624  				if err != nil {
   625  					fmt.Printf("testTCPServer.run(): clientConn.Write failed: %s\n", err)
   626  					return
   627  				}
   628  			}
   629  		}()
   630  	}
   631  }
   632  
   633  func (server *testTCPServer) stop() {
   634  	server.clientConns.CloseAll()
   635  	server.tcpListener.Close()
   636  	server.workers.Wait()
   637  }
   638  
   639  type testTCPClient struct {
   640  	conn net.Conn
   641  }
   642  
   643  func startTestTCPClient(
   644  	tunDeviceName, serverIPAddress string) (*testTCPClient, error) {
   645  
   646  	// This is a simplified version of the low-level TCP dial
   647  	// code in psiphon/TCPConn, which supports BindToDevice.
   648  	// It does not resolve domain names and does not have an
   649  	// explicit timeout.
   650  
   651  	var ipv4 [4]byte
   652  	var ipv6 [16]byte
   653  	var domain int
   654  	var sockAddr syscall.Sockaddr
   655  
   656  	ipAddr := net.ParseIP(serverIPAddress)
   657  	if ipAddr == nil {
   658  		return nil, fmt.Errorf("net.ParseIP failed")
   659  	}
   660  
   661  	if ipAddr.To4() != nil {
   662  		copy(ipv4[:], ipAddr.To4())
   663  		domain = syscall.AF_INET
   664  		sockAddr = &syscall.SockaddrInet4{Addr: ipv4, Port: TCP_PORT}
   665  	} else {
   666  		copy(ipv6[:], ipAddr.To16())
   667  		domain = syscall.AF_INET6
   668  		sockAddr = &syscall.SockaddrInet6{Addr: ipv6, Port: TCP_PORT}
   669  	}
   670  
   671  	socketFd, err := syscall.Socket(domain, syscall.SOCK_STREAM, 0)
   672  	if err != nil {
   673  		return nil, fmt.Errorf("syscall.Socket failed: %s", err)
   674  	}
   675  
   676  	err = BindToDevice(socketFd, tunDeviceName)
   677  	if err != nil {
   678  		syscall.Close(socketFd)
   679  		return nil, fmt.Errorf("BindToDevice failed: %s", err)
   680  	}
   681  
   682  	err = syscall.Connect(socketFd, sockAddr)
   683  	if err != nil {
   684  		syscall.Close(socketFd)
   685  		return nil, fmt.Errorf("syscall.Connect failed: %s", err)
   686  	}
   687  
   688  	file := os.NewFile(uintptr(socketFd), "")
   689  	conn, err := net.FileConn(file)
   690  	file.Close()
   691  	if err != nil {
   692  		return nil, fmt.Errorf("net.FileConn failed: %s", err)
   693  	}
   694  
   695  	return &testTCPClient{
   696  		conn: conn,
   697  	}, nil
   698  }
   699  
   700  func (client *testTCPClient) Read(p []byte) (n int, err error) {
   701  	n, err = client.conn.Read(p)
   702  	return
   703  }
   704  
   705  func (client *testTCPClient) Write(p []byte) (n int, err error) {
   706  	n, err = client.conn.Write(p)
   707  	return
   708  }
   709  
   710  func (client *testTCPClient) stop() {
   711  	client.conn.Close()
   712  }
   713  
   714  func testDNSClient(useIPv6 bool, tunDeviceName string) error {
   715  
   716  	var ipv4 [4]byte
   717  	var ipv6 [16]byte
   718  	var domain int
   719  	var sockAddr syscall.Sockaddr
   720  
   721  	if !useIPv6 {
   722  		copy(ipv4[:], transparentDNSResolverIPv4Address)
   723  		domain = syscall.AF_INET
   724  		sockAddr = &syscall.SockaddrInet4{Addr: ipv4, Port: portNumberDNS}
   725  	} else {
   726  		copy(ipv6[:], transparentDNSResolverIPv6Address)
   727  		domain = syscall.AF_INET6
   728  		sockAddr = &syscall.SockaddrInet6{Addr: ipv6, Port: portNumberDNS}
   729  	}
   730  
   731  	socketFd, err := syscall.Socket(domain, syscall.SOCK_DGRAM, 0)
   732  	if err != nil {
   733  		return err
   734  	}
   735  
   736  	err = BindToDevice(socketFd, tunDeviceName)
   737  	if err != nil {
   738  		syscall.Close(socketFd)
   739  		return err
   740  	}
   741  
   742  	err = syscall.Connect(socketFd, sockAddr)
   743  	if err != nil {
   744  		syscall.Close(socketFd)
   745  		return err
   746  	}
   747  
   748  	file := os.NewFile(uintptr(socketFd), "")
   749  	conn, err := net.FileConn(file)
   750  	file.Close()
   751  	if err != nil {
   752  		return err
   753  	}
   754  	defer conn.Close()
   755  
   756  	dnsConn := &dns.Conn{Conn: conn}
   757  	defer dnsConn.Close()
   758  
   759  	query := new(dns.Msg)
   760  	query.SetQuestion(dns.Fqdn("www.example.org"), dns.TypeA)
   761  	query.RecursionDesired = true
   762  
   763  	dnsConn.WriteMsg(query)
   764  	_, err = dnsConn.ReadMsg()
   765  	if err != nil {
   766  		return err
   767  	}
   768  
   769  	return nil
   770  }
   771  
   772  type testLogger struct {
   773  	packetMetrics chan common.LogFields
   774  }
   775  
   776  func newTestLogger(wantLastPacketMetrics bool) *testLogger {
   777  
   778  	var packetMetrics chan common.LogFields
   779  	if wantLastPacketMetrics {
   780  		packetMetrics = make(chan common.LogFields, CONCURRENT_CLIENT_COUNT)
   781  	}
   782  
   783  	return &testLogger{
   784  		packetMetrics: packetMetrics,
   785  	}
   786  }
   787  
   788  func (logger *testLogger) WithTrace() common.LogTrace {
   789  	return &testLoggerTrace{trace: stacktrace.GetParentFunctionName()}
   790  }
   791  
   792  func (logger *testLogger) WithTraceFields(fields common.LogFields) common.LogTrace {
   793  	return &testLoggerTrace{
   794  		trace:  stacktrace.GetParentFunctionName(),
   795  		fields: fields,
   796  	}
   797  }
   798  
   799  func (logger *testLogger) LogMetric(metric string, fields common.LogFields) {
   800  
   801  	fmt.Printf("METRIC: %s: %+v\n", metric, fields)
   802  
   803  	if metric == "server_packet_metrics" && logger.packetMetrics != nil {
   804  		select {
   805  		case logger.packetMetrics <- fields:
   806  		default:
   807  		}
   808  	}
   809  }
   810  
   811  func (logger *testLogger) getLastPacketMetrics() common.LogFields {
   812  	if logger.packetMetrics == nil {
   813  		return nil
   814  	}
   815  
   816  	// Implicitly asserts that packet metrics will be emitted
   817  	// within PACKET_METRICS_TIMEOUT; if not, the test will fail.
   818  
   819  	select {
   820  	case fields := <-logger.packetMetrics:
   821  		return fields
   822  	case <-time.After(PACKET_METRICS_TIMEOUT):
   823  		return nil
   824  	}
   825  }
   826  
   827  type testLoggerTrace struct {
   828  	trace  string
   829  	fields common.LogFields
   830  }
   831  
   832  func (logger *testLoggerTrace) log(priority, message string) {
   833  	now := time.Now().UTC().Format(time.RFC3339)
   834  	if len(logger.fields) == 0 {
   835  		fmt.Printf(
   836  			"[%s] %s: %s: %s\n",
   837  			now, priority, logger.trace, message)
   838  	} else {
   839  		fmt.Printf(
   840  			"[%s] %s: %s: %s %+v\n",
   841  			now, priority, logger.trace, message, logger.fields)
   842  	}
   843  }
   844  
   845  func (logger *testLoggerTrace) Debug(args ...interface{}) {
   846  	logger.log("DEBUG", fmt.Sprint(args...))
   847  }
   848  
   849  func (logger *testLoggerTrace) Info(args ...interface{}) {
   850  	logger.log("INFO", fmt.Sprint(args...))
   851  }
   852  
   853  func (logger *testLoggerTrace) Warning(args ...interface{}) {
   854  	logger.log("WARNING", fmt.Sprint(args...))
   855  }
   856  
   857  func (logger *testLoggerTrace) Error(args ...interface{}) {
   858  	logger.log("ERROR", fmt.Sprint(args...))
   859  }