github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/packetman/packetman_linux_test.go (about)

     1  // +build PSIPHON_RUN_PACKET_MANIPULATOR_TEST
     2  
     3  /*
     4   * Copyright (c) 2020, Psiphon Inc.
     5   * All rights reserved.
     6   *
     7   * This program is free software: you can redistribute it and/or modify
     8   * it under the terms of the GNU General Public License as published by
     9   * the Free Software Foundation, either version 3 of the License, or
    10   * (at your option) any later version.
    11   *
    12   * This program is distributed in the hope that it will be useful,
    13   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    14   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    15   * GNU General Public License for more details.
    16   *
    17   * You should have received a copy of the GNU General Public License
    18   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    19   *
    20   */
    21  
    22  package packetman
    23  
    24  import (
    25  	"fmt"
    26  	"io"
    27  	"io/ioutil"
    28  	"net"
    29  	"net/http"
    30  	"strconv"
    31  	"testing"
    32  	"time"
    33  
    34  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    35  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/stacktrace"
    36  )
    37  
    38  func TestPacketManipulatorIPv4(t *testing.T) {
    39  	testPacketManipulator(false, t)
    40  }
    41  
    42  func TestPacketManipulatorIPv6(t *testing.T) {
    43  	testPacketManipulator(true, t)
    44  }
    45  
    46  func testPacketManipulator(useIPv6 bool, t *testing.T) {
    47  
    48  	// Test: run a Manipulator in front of a web server; make an HTTP request;
    49  	// the expected transformation spec should be executed (as reported by
    50  	// GetAppliedSpecName) and the request must succeed.
    51  
    52  	ipv4, ipv6, err := common.GetRoutableInterfaceIPAddresses()
    53  	if err != nil {
    54  		t.Fatalf("GetRoutableInterfaceIPAddressesfailed: %v", err)
    55  	}
    56  
    57  	network := "tcp4"
    58  	address := net.JoinHostPort(ipv4.String(), "0")
    59  	if useIPv6 {
    60  		if ipv6 == nil {
    61  			t.Skipf("test unsupported: no IP address")
    62  		}
    63  		network = "tcp6"
    64  		address = net.JoinHostPort(ipv6.String(), "0")
    65  	}
    66  
    67  	listener, err := net.Listen(network, address)
    68  	if err != nil {
    69  		t.Fatalf("net.Listen failed: %v", err)
    70  	}
    71  	defer listener.Close()
    72  
    73  	hostStr, portStr, err := net.SplitHostPort(listener.Addr().String())
    74  	if err != nil {
    75  		t.Fatalf("net.SplitHostPort failed: %s", err.Error())
    76  	}
    77  	listenerPort, _ := strconv.Atoi(portStr)
    78  
    79  	// [["TCP-flags S"]] replaces the original SYN-ACK packet with a single
    80  	// SYN packet, implementing TCP simultaneous open.
    81  
    82  	testSpecName := "test-spec"
    83  	extraDataValue := "extra-data"
    84  	config := &Config{
    85  		Logger:        newTestLogger(),
    86  		ProtocolPorts: []int{listenerPort},
    87  		Specs:         []*Spec{&Spec{Name: testSpecName, PacketSpecs: [][]string{[]string{"TCP-flags S"}}}},
    88  		SelectSpecName: func(protocolPort int, _ net.IP) (string, interface{}) {
    89  			if protocolPort == listenerPort {
    90  				return testSpecName, extraDataValue
    91  			}
    92  			return "", nil
    93  		},
    94  		QueueNumber: 1,
    95  	}
    96  
    97  	m, err := NewManipulator(config)
    98  	if err != nil {
    99  		t.Fatalf("NewManipulator failed: %v", err)
   100  	}
   101  
   102  	err = m.Start()
   103  	if err != nil {
   104  		t.Fatalf("Manipulator.Start failed: %v", err)
   105  	}
   106  	defer m.Stop()
   107  
   108  	go func() {
   109  		serveMux := http.NewServeMux()
   110  		serveMux.HandleFunc("/", func(w http.ResponseWriter, _ *http.Request) {
   111  			io.WriteString(w, "test-response\n")
   112  		})
   113  
   114  		server := &http.Server{
   115  			Handler: serveMux,
   116  			ConnState: func(conn net.Conn, state http.ConnState) {
   117  				if state == http.StateNew {
   118  					localAddr := conn.LocalAddr().(*net.TCPAddr)
   119  					remoteAddr := conn.RemoteAddr().(*net.TCPAddr)
   120  					specName, extraData, err := m.GetAppliedSpecName(localAddr, remoteAddr)
   121  					if err != nil {
   122  						t.Fatalf("GetAppliedSpecName failed: %v", err)
   123  					}
   124  					if specName != testSpecName {
   125  						t.Fatalf("unexpected spec name: %s", specName)
   126  					}
   127  					extraDataStr, ok := extraData.(string)
   128  					if !ok || extraDataStr != extraDataValue {
   129  						t.Fatalf("unexpected extra data value: %v", extraData)
   130  					}
   131  				}
   132  			},
   133  		}
   134  
   135  		server.Serve(listener)
   136  	}()
   137  
   138  	httpClient := &http.Client{
   139  		Timeout: 30 * time.Second,
   140  	}
   141  
   142  	response, err := httpClient.Get(fmt.Sprintf("http://%s:%s", hostStr, portStr))
   143  	if err != nil {
   144  		t.Fatalf("http.Get failed: %v", err)
   145  	}
   146  	defer response.Body.Close()
   147  	_, err = ioutil.ReadAll(response.Body)
   148  	if err != nil {
   149  		t.Fatalf("ioutil.ReadAll failed: %v", err)
   150  	}
   151  
   152  	if response.StatusCode != http.StatusOK {
   153  		t.Fatalf("unexpected response code: %d", response.StatusCode)
   154  	}
   155  }
   156  
   157  func newTestLogger() common.Logger {
   158  	return &testLogger{}
   159  }
   160  
   161  type testLogger struct {
   162  }
   163  
   164  func (logger *testLogger) WithTrace() common.LogTrace {
   165  	return &testLogTrace{
   166  		trace: stacktrace.GetParentFunctionName(),
   167  	}
   168  }
   169  
   170  func (logger *testLogger) WithTraceFields(fields common.LogFields) common.LogTrace {
   171  	return &testLogTrace{
   172  		trace:  stacktrace.GetParentFunctionName(),
   173  		fields: fields,
   174  	}
   175  }
   176  
   177  func (logger *testLogger) LogMetric(metric string, fields common.LogFields) {
   178  }
   179  
   180  type testLogTrace struct {
   181  	trace  string
   182  	fields common.LogFields
   183  }
   184  
   185  func (log *testLogTrace) log(
   186  	noticeType string, args ...interface{}) {
   187  
   188  	fmt.Printf("[%s] %s: %+v: %s\n",
   189  		noticeType,
   190  		log.trace,
   191  		log.fields,
   192  		fmt.Sprint(args...))
   193  }
   194  
   195  func (log *testLogTrace) Debug(args ...interface{}) {
   196  	log.log("DEBUG", args...)
   197  }
   198  
   199  func (log *testLogTrace) Info(args ...interface{}) {
   200  	log.log("INFO", args...)
   201  }
   202  
   203  func (log *testLogTrace) Warning(args ...interface{}) {
   204  	log.log("ALERT", args...)
   205  }
   206  
   207  func (log *testLogTrace) Error(args ...interface{}) {
   208  	log.log("ERROR", args...)
   209  }