github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/server/passthrough_test.go (about)

     1  /*
     2   * Copyright (c) 2020, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package server
    21  
    22  import (
    23  	"bytes"
    24  	"context"
    25  	"crypto/tls"
    26  	"crypto/x509"
    27  	"encoding/json"
    28  	"errors"
    29  	"fmt"
    30  	"io/ioutil"
    31  	"net"
    32  	"net/http"
    33  	"os"
    34  	"path/filepath"
    35  	"sync"
    36  	"sync/atomic"
    37  	"testing"
    38  	"time"
    39  
    40  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon"
    41  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    42  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
    43  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/protocol"
    44  )
    45  
    46  func TestPassthrough(t *testing.T) {
    47  	testPassthrough(t, false)
    48  }
    49  
    50  func TestLegacyPassthrough(t *testing.T) {
    51  	testPassthrough(t, true)
    52  }
    53  
    54  func testPassthrough(t *testing.T, legacy bool) {
    55  
    56  	psiphon.SetEmitDiagnosticNotices(true, true)
    57  
    58  	// Run passthrough web server
    59  
    60  	webServerCertificate, webServerPrivateKey, err := common.GenerateWebServerCertificate("example.org")
    61  	if err != nil {
    62  		t.Fatalf("common.GenerateWebServerCertificate failed: %s", err)
    63  	}
    64  
    65  	webListener, err := net.Listen("tcp", "127.0.0.1:0")
    66  	if err != nil {
    67  		t.Fatalf("net.Listen failed: %s", err)
    68  	}
    69  	defer webListener.Close()
    70  
    71  	webCertificate, err := tls.X509KeyPair(
    72  		[]byte(webServerCertificate),
    73  		[]byte(webServerPrivateKey))
    74  	if err != nil {
    75  		t.Fatalf("tls.X509KeyPair failed: %s", err)
    76  	}
    77  
    78  	webListener = tls.NewListener(webListener, &tls.Config{
    79  		Certificates: []tls.Certificate{webCertificate},
    80  	})
    81  
    82  	webServerAddress := webListener.Addr().String()
    83  
    84  	webResponseBody := []byte(prng.HexString(32))
    85  
    86  	webServer := http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
    87  		w.Write(webResponseBody)
    88  	})
    89  
    90  	go func() {
    91  		http.Serve(webListener, webServer)
    92  	}()
    93  
    94  	// Run Psiphon server
    95  
    96  	tunnelProtocol := protocol.TUNNEL_PROTOCOL_UNFRONTED_MEEK_SESSION_TICKET
    97  
    98  	generateConfigParams := &GenerateConfigParams{
    99  		ServerIPAddress:      "127.0.0.1",
   100  		EnableSSHAPIRequests: true,
   101  		WebServerPort:        8000,
   102  		TunnelProtocolPorts:  map[string]int{tunnelProtocol: 4000},
   103  		Passthrough:          true,
   104  		LegacyPassthrough:    legacy,
   105  	}
   106  
   107  	serverConfigJSON, _, _, _, encodedServerEntry, err := GenerateConfig(generateConfigParams)
   108  	if err != nil {
   109  		t.Fatalf("error generating server config: %s", err)
   110  	}
   111  
   112  	var serverConfig map[string]interface{}
   113  	json.Unmarshal(serverConfigJSON, &serverConfig)
   114  
   115  	serverConfig["LogFilename"] = filepath.Join(testDataDirName, "psiphond.log")
   116  	serverConfig["LogLevel"] = "debug"
   117  	serverConfig["TunnelProtocolPassthroughAddresses"] = map[string]string{tunnelProtocol: webServerAddress}
   118  
   119  	serverConfigJSON, _ = json.Marshal(serverConfig)
   120  
   121  	serverWaitGroup := new(sync.WaitGroup)
   122  	serverWaitGroup.Add(1)
   123  	go func() {
   124  		defer serverWaitGroup.Done()
   125  		err := RunServices(serverConfigJSON)
   126  		if err != nil {
   127  			t.Errorf("error running server: %s", err)
   128  		}
   129  	}()
   130  
   131  	defer func() {
   132  		p, _ := os.FindProcess(os.Getpid())
   133  		p.Signal(os.Interrupt)
   134  		serverWaitGroup.Wait()
   135  	}()
   136  
   137  	// TODO: monitor logs for more robust wait-until-loaded.
   138  	time.Sleep(1 * time.Second)
   139  
   140  	// Test: normal client connects successfully
   141  
   142  	clientConfigJSON := fmt.Sprintf(`
   143  		    {
   144  		    	"DataRootDirectory" : "%s",
   145  		        "ClientPlatform" : "Windows",
   146  		        "ClientVersion" : "0",
   147  		        "SponsorId" : "0",
   148  		        "PropagationChannelId" : "0",
   149  		        "TargetServerEntry" : "%s"
   150  		    }`, testDataDirName, string(encodedServerEntry))
   151  
   152  	clientConfig, err := psiphon.LoadConfig([]byte(clientConfigJSON))
   153  	if err != nil {
   154  		t.Fatalf("error processing configuration file: %s", err)
   155  	}
   156  
   157  	err = clientConfig.Commit(false)
   158  	if err != nil {
   159  		t.Fatalf("error committing configuration file: %s", err)
   160  	}
   161  
   162  	err = psiphon.OpenDataStore(clientConfig)
   163  	if err != nil {
   164  		t.Fatalf("error initializing client datastore: %s", err)
   165  	}
   166  	defer psiphon.CloseDataStore()
   167  
   168  	controller, err := psiphon.NewController(clientConfig)
   169  	if err != nil {
   170  		t.Fatalf("error creating client controller: %s", err)
   171  	}
   172  
   173  	tunnelEstablished := make(chan struct{}, 1)
   174  
   175  	psiphon.SetNoticeWriter(psiphon.NewNoticeReceiver(
   176  		func(notice []byte) {
   177  			noticeType, payload, err := psiphon.GetNotice(notice)
   178  			if err != nil {
   179  				return
   180  			}
   181  			if noticeType == "Tunnels" {
   182  				count := int(payload["count"].(float64))
   183  				if count >= 1 {
   184  					tunnelEstablished <- struct{}{}
   185  				}
   186  			}
   187  		}))
   188  
   189  	ctx, cancelFunc := context.WithCancel(context.Background())
   190  	controllerWaitGroup := new(sync.WaitGroup)
   191  	controllerWaitGroup.Add(1)
   192  	go func() {
   193  		defer controllerWaitGroup.Done()
   194  		controller.Run(ctx)
   195  	}()
   196  	<-tunnelEstablished
   197  	cancelFunc()
   198  	controllerWaitGroup.Wait()
   199  
   200  	// Test: passthrough
   201  
   202  	// Non-psiphon HTTPS request routed to passthrough web server
   203  
   204  	verifiedCertificate := int32(0)
   205  
   206  	httpClient := &http.Client{
   207  		Transport: &http.Transport{
   208  			TLSClientConfig: &tls.Config{
   209  				InsecureSkipVerify: true,
   210  				VerifyPeerCertificate: func(rawCerts [][]byte, _ [][]*x509.Certificate) error {
   211  					if len(rawCerts) < 1 {
   212  						return errors.New("no certificate to verify")
   213  					}
   214  					if !bytes.Equal(rawCerts[0], []byte(webCertificate.Certificate[0])) {
   215  						return errors.New("unexpected certificate")
   216  					}
   217  					atomic.StoreInt32(&verifiedCertificate, 1)
   218  					return nil
   219  				},
   220  			},
   221  		},
   222  	}
   223  
   224  	response, err := httpClient.Get("https://" + webServerAddress)
   225  	if err != nil {
   226  		t.Fatalf("http.Get failed: %s", err)
   227  	}
   228  	defer response.Body.Close()
   229  
   230  	if atomic.LoadInt32(&verifiedCertificate) != 1 {
   231  		t.Fatalf("certificate not verified")
   232  	}
   233  
   234  	if response.StatusCode != http.StatusOK {
   235  		t.Fatalf("unexpected response.StatusCode: %d", response.StatusCode)
   236  	}
   237  
   238  	responseBody, err := ioutil.ReadAll(response.Body)
   239  	if err != nil {
   240  		t.Fatalf("ioutil.ReadAll failed: %s", err)
   241  	}
   242  
   243  	if !bytes.Equal(responseBody, webResponseBody) {
   244  		t.Fatalf("unexpected responseBody: %s", string(responseBody))
   245  	}
   246  }