github.com/psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/remoteServerList_test.go (about)

     1  /*
     2   * Copyright (c) 2016, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package psiphon
    21  
    22  import (
    23  	"bytes"
    24  	"context"
    25  	"crypto/md5"
    26  	"encoding/base64"
    27  	"encoding/hex"
    28  	"fmt"
    29  	"io"
    30  	"io/ioutil"
    31  	"net"
    32  	"net/http"
    33  	"net/url"
    34  	"os"
    35  	"path"
    36  	"path/filepath"
    37  	"sync"
    38  	"sync/atomic"
    39  	"syscall"
    40  	"testing"
    41  	"time"
    42  
    43  	socks "github.com/Psiphon-Labs/goptlib"
    44  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    45  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/osl"
    46  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
    47  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/server"
    48  )
    49  
    50  // TODO: TestCommonRemoteServerList (this is currently covered by controller_test.go)
    51  
    52  func TestObfuscatedRemoteServerLists(t *testing.T) {
    53  	testObfuscatedRemoteServerLists(t, false)
    54  }
    55  
    56  func TestObfuscatedRemoteServerListsOmitMD5Sums(t *testing.T) {
    57  	testObfuscatedRemoteServerLists(t, true)
    58  }
    59  
    60  // Each instance testObfuscatedRemoteServerLists runs a server which binds to
    61  // specific network ports. Server shutdown, via SIGTERM, is not supported on
    62  // Windows. Shutdown is not necessary for these tests, but, without shutdown,
    63  // multiple testObfuscatedRemoteServerLists calls fail when trying to reuse
    64  // network ports. This workaround selects unique ports for each server.
    65  var nextServerPort int32 = 8000
    66  
    67  func testObfuscatedRemoteServerLists(t *testing.T, omitMD5Sums bool) {
    68  
    69  	testDataDirName, err := ioutil.TempDir("", "psiphon-remote-server-list-test")
    70  	if err != nil {
    71  		t.Fatalf("TempDir failed: %s", err)
    72  	}
    73  	defer os.RemoveAll(testDataDirName)
    74  
    75  	//
    76  	// create a server
    77  	//
    78  
    79  	serverIPv4Address, serverIPv6Address, err := common.GetRoutableInterfaceIPAddresses()
    80  	if err != nil {
    81  		t.Fatalf("error getting server IP address: %s", err)
    82  	}
    83  	serverIPAddress := ""
    84  	if serverIPv4Address != nil {
    85  		serverIPAddress = serverIPv4Address.String()
    86  	} else {
    87  		serverIPAddress = serverIPv6Address.String()
    88  	}
    89  
    90  	serverConfigJSON, _, _, _, encodedServerEntry, err := server.GenerateConfig(
    91  		&server.GenerateConfigParams{
    92  			ServerIPAddress:      serverIPAddress,
    93  			EnableSSHAPIRequests: true,
    94  			WebServerPort:        int(atomic.AddInt32(&nextServerPort, 1)),
    95  			TunnelProtocolPorts:  map[string]int{"OSSH": int(atomic.AddInt32(&nextServerPort, 1))},
    96  			LogFilename:          filepath.Join(testDataDirName, "psiphond.log"),
    97  			LogLevel:             "debug",
    98  
    99  			// "defer os.RemoveAll" will cause a log write error
   100  			SkipPanickingLogWriter: true,
   101  		})
   102  	if err != nil {
   103  		t.Fatalf("error generating server config: %s", err)
   104  	}
   105  
   106  	//
   107  	// pave OSLs
   108  	//
   109  
   110  	oslConfigJSONTemplate := `
   111      {
   112        "Schemes" : [
   113          {
   114            "Epoch" : "%s",
   115            "Regions" : [],
   116            "PropagationChannelIDs" : ["%s"],
   117            "MasterKey" : "vwab2WY3eNyMBpyFVPtsivMxF4MOpNHM/T7rHJIXctg=",
   118            "SeedSpecs" : [
   119              {
   120                "ID" : "KuP2V6gLcROIFzb/27fUVu4SxtEfm2omUoISlrWv1mA=",
   121                "UpstreamSubnets" : ["0.0.0.0/0"],
   122                "Targets" :
   123                {
   124                    "BytesRead" : 1,
   125                    "BytesWritten" : 1,
   126                    "PortForwardDurationNanoseconds" : 1
   127                }
   128              }
   129            ],
   130            "SeedSpecThreshold" : 1,
   131            "SeedPeriodNanoseconds" : %d,
   132            "SeedPeriodKeySplits": [
   133              {
   134                "Total": 1,
   135                "Threshold": 1
   136              }
   137            ]
   138          }
   139        ]
   140      }`
   141  
   142  	now := time.Now().UTC()
   143  	seedPeriod := 24 * time.Hour
   144  	epoch := now.Truncate(seedPeriod)
   145  	epochStr := epoch.Format(time.RFC3339Nano)
   146  
   147  	propagationChannelID := prng.HexString(8)
   148  
   149  	oslConfigJSON := fmt.Sprintf(
   150  		oslConfigJSONTemplate,
   151  		epochStr,
   152  		propagationChannelID,
   153  		seedPeriod)
   154  
   155  	oslConfig, err := osl.LoadConfig([]byte(oslConfigJSON))
   156  	if err != nil {
   157  		t.Fatalf("error loading OSL config: %s", err)
   158  	}
   159  
   160  	signingPublicKey, signingPrivateKey, err := common.GenerateAuthenticatedDataPackageKeys()
   161  	if err != nil {
   162  		t.Fatalf("error generating package keys: %s", err)
   163  	}
   164  
   165  	var omitMD5SumsSchemes []int
   166  	if omitMD5Sums {
   167  		omitMD5SumsSchemes = []int{0}
   168  	}
   169  	// First Pave() call is to get the OSL ID to pave into
   170  
   171  	oslID := ""
   172  
   173  	omitEmptyOSLsSchemes := []int{}
   174  
   175  	paveFiles, err := oslConfig.Pave(
   176  		time.Time{},
   177  		epoch,
   178  		propagationChannelID,
   179  		signingPublicKey,
   180  		signingPrivateKey,
   181  		map[string][]string{},
   182  		omitMD5SumsSchemes,
   183  		omitEmptyOSLsSchemes,
   184  		func(logInfo *osl.PaveLogInfo) {
   185  			oslID = logInfo.OSLID
   186  		})
   187  	if err != nil {
   188  		t.Fatalf("error paving OSL files: %s", err)
   189  	}
   190  
   191  	omitEmptyOSLsSchemes = []int{0}
   192  
   193  	paveFiles, err = oslConfig.Pave(
   194  		time.Time{},
   195  		epoch,
   196  		propagationChannelID,
   197  		signingPublicKey,
   198  		signingPrivateKey,
   199  		map[string][]string{
   200  			oslID: {string(encodedServerEntry)},
   201  		},
   202  		omitMD5SumsSchemes,
   203  		omitEmptyOSLsSchemes,
   204  		nil)
   205  	if err != nil {
   206  		t.Fatalf("error paving OSL files: %s", err)
   207  	}
   208  
   209  	//
   210  	// mock seeding SLOKs
   211  	//
   212  
   213  	config := Config{
   214  		DataRootDirectory:    testDataDirName,
   215  		PropagationChannelId: "0",
   216  		SponsorId:            "0"}
   217  	err = config.Commit(false)
   218  	if err != nil {
   219  		t.Fatalf("Error initializing config: %s", err)
   220  	}
   221  
   222  	err = OpenDataStore(&config)
   223  	if err != nil {
   224  		t.Fatalf("error initializing client datastore: %s", err)
   225  	}
   226  	defer CloseDataStore()
   227  
   228  	if CountServerEntries() > 0 {
   229  		t.Fatalf("unexpected server entries")
   230  	}
   231  
   232  	seedState := oslConfig.NewClientSeedState("", propagationChannelID, nil)
   233  	seedPortForward := seedState.NewClientSeedPortForward(net.ParseIP("0.0.0.0"))
   234  	seedPortForward.UpdateProgress(1, 1, 1)
   235  	payload := seedState.GetSeedPayload()
   236  	if len(payload.SLOKs) != 1 {
   237  		t.Fatalf("expected 1 SLOKs, got %d", len(payload.SLOKs))
   238  	}
   239  
   240  	SetSLOK(payload.SLOKs[0].ID, payload.SLOKs[0].Key)
   241  
   242  	//
   243  	// run mock remote server list host
   244  	//
   245  
   246  	// Exercise using multiple download URLs
   247  
   248  	var remoteServerListListeners [2]net.Listener
   249  	var remoteServerListHostAddresses [2]string
   250  
   251  	for i := 0; i < len(remoteServerListListeners); i++ {
   252  		remoteServerListListeners[i], err = net.Listen("tcp", net.JoinHostPort(serverIPAddress, "0"))
   253  		if err != nil {
   254  			t.Fatalf("net.Listen error: %s", err)
   255  		}
   256  		defer remoteServerListListeners[i].Close()
   257  		remoteServerListHostAddresses[i] = remoteServerListListeners[i].Addr().String()
   258  	}
   259  
   260  	// The common remote server list fetches will 404
   261  	remoteServerListURL := fmt.Sprintf("http://%s/server_list_compressed", remoteServerListHostAddresses[0])
   262  
   263  	obfuscatedServerListRootURLsJSONConfig := "["
   264  	obfuscatedServerListRootURLs := make([]string, len(remoteServerListHostAddresses))
   265  
   266  	httpServers := make(chan *http.Server, len(remoteServerListHostAddresses))
   267  
   268  	for i := 0; i < len(remoteServerListHostAddresses); i++ {
   269  
   270  		obfuscatedServerListRootURLs[i] = fmt.Sprintf("http://%s/", remoteServerListHostAddresses[i])
   271  
   272  		obfuscatedServerListRootURLsJSONConfig += fmt.Sprintf(
   273  			"{\"URL\" : \"%s\"}", base64.StdEncoding.EncodeToString([]byte(obfuscatedServerListRootURLs[i])))
   274  		if i == len(remoteServerListHostAddresses)-1 {
   275  			obfuscatedServerListRootURLsJSONConfig += "]"
   276  		} else {
   277  			obfuscatedServerListRootURLsJSONConfig += ","
   278  		}
   279  
   280  		go func(listener net.Listener, remoteServerListHostAddress string) {
   281  			startTime := time.Now()
   282  			serveMux := http.NewServeMux()
   283  			for _, paveFile := range paveFiles {
   284  				file := paveFile
   285  				serveMux.HandleFunc("/"+file.Name, func(w http.ResponseWriter, req *http.Request) {
   286  					md5sum := md5.Sum(file.Contents)
   287  					w.Header().Add("Content-Type", "application/octet-stream")
   288  					w.Header().Add("ETag", fmt.Sprintf("\"%s\"", hex.EncodeToString(md5sum[:])))
   289  					http.ServeContent(w, req, file.Name, startTime, bytes.NewReader(file.Contents))
   290  				})
   291  			}
   292  			httpServer := &http.Server{
   293  				Addr:    remoteServerListHostAddress,
   294  				Handler: serveMux,
   295  			}
   296  			httpServers <- httpServer
   297  			httpServer.Serve(listener)
   298  		}(remoteServerListListeners[i], remoteServerListHostAddresses[i])
   299  	}
   300  
   301  	defer func() {
   302  		for i := 0; i < len(remoteServerListHostAddresses); i++ {
   303  			httpServer := <-httpServers
   304  			httpServer.Close()
   305  		}
   306  	}()
   307  
   308  	//
   309  	// run Psiphon server
   310  	//
   311  
   312  	go func() {
   313  		err := server.RunServices(serverConfigJSON)
   314  		if err != nil {
   315  			// TODO: wrong goroutine for t.FatalNow()
   316  			t.Errorf("error running server: %s", err)
   317  		}
   318  	}()
   319  
   320  	process, err := os.FindProcess(os.Getpid())
   321  	if err != nil {
   322  		t.Fatalf("os.FindProcess error: %s", err)
   323  	}
   324  	defer process.Signal(syscall.SIGTERM)
   325  
   326  	//
   327  	// disrupt remote server list downloads
   328  	//
   329  
   330  	disruptorListener, err := net.Listen("tcp", "127.0.0.1:0")
   331  	if err != nil {
   332  		t.Fatalf("net.Listen error: %s", err)
   333  	}
   334  	defer disruptorListener.Close()
   335  
   336  	disruptorProxyAddress := disruptorListener.Addr().String()
   337  	disruptorProxyURL := "socks4a://" + disruptorProxyAddress
   338  
   339  	go func() {
   340  		listener := socks.NewSocksListener(disruptorListener)
   341  		for {
   342  			localConn, err := listener.AcceptSocks()
   343  			if err != nil {
   344  				if e, ok := err.(net.Error); ok && e.Temporary() {
   345  					fmt.Printf("disruptor proxy temporary accept error: %s\n", err)
   346  					continue
   347  				}
   348  				fmt.Printf("disruptor proxy accept error: %s\n", err)
   349  				return
   350  			}
   351  			go func() {
   352  				remoteConn, err := net.Dial("tcp", localConn.Req.Target)
   353  				if err != nil {
   354  					fmt.Printf("disruptor proxy dial error: %s\n", err)
   355  					return
   356  				}
   357  				err = localConn.Grant(&net.TCPAddr{IP: net.ParseIP("0.0.0.0"), Port: 0})
   358  				if err != nil {
   359  					fmt.Printf("disruptor proxy grant error: %s\n", err)
   360  					return
   361  				}
   362  
   363  				waitGroup := new(sync.WaitGroup)
   364  				waitGroup.Add(1)
   365  				go func() {
   366  					defer waitGroup.Done()
   367  					io.Copy(remoteConn, localConn)
   368  				}()
   369  				if common.Contains(remoteServerListHostAddresses[:], localConn.Req.Target) {
   370  					io.CopyN(localConn, remoteConn, 500)
   371  				} else {
   372  					io.Copy(localConn, remoteConn)
   373  				}
   374  				localConn.Close()
   375  				remoteConn.Close()
   376  				waitGroup.Wait()
   377  			}()
   378  		}
   379  	}()
   380  
   381  	//
   382  	// connect to Psiphon server with Psiphon client
   383  	//
   384  
   385  	SetEmitDiagnosticNotices(true, true)
   386  
   387  	// Note: calling LoadConfig ensures all *int config fields are initialized
   388  	clientConfigJSONTemplate := `
   389      {
   390          "ClientPlatform" : "",
   391          "ClientVersion" : "0",
   392          "SponsorId" : "0",
   393          "PropagationChannelId" : "0",
   394          "ConnectionWorkerPoolSize" : 1,
   395          "EstablishTunnelPausePeriodSeconds" : 1,
   396          "FetchRemoteServerListRetryPeriodMilliseconds" : 250,
   397          "RemoteServerListSignaturePublicKey" : "%s",
   398          "RemoteServerListUrl" : "%s",
   399          "ObfuscatedServerListRootURLs" : %s,
   400          "UpstreamProxyUrl" : "%s",
   401          "UpstreamProxyAllowAllServerEntrySources" : true
   402      }`
   403  
   404  	clientConfigJSON := fmt.Sprintf(
   405  		clientConfigJSONTemplate,
   406  		signingPublicKey,
   407  		remoteServerListURL,
   408  		obfuscatedServerListRootURLsJSONConfig,
   409  		disruptorProxyURL)
   410  
   411  	clientConfig, err := LoadConfig([]byte(clientConfigJSON))
   412  	if err != nil {
   413  		t.Fatalf("error processing configuration file: %s", err)
   414  	}
   415  
   416  	clientConfig.DataRootDirectory = testDataDirName
   417  
   418  	err = clientConfig.Commit(false)
   419  	if err != nil {
   420  		t.Fatalf("error committing configuration file: %s", err)
   421  	}
   422  
   423  	controller, err := NewController(clientConfig)
   424  	if err != nil {
   425  		t.Fatalf("error creating client controller: %s", err)
   426  	}
   427  
   428  	tunnelEstablished := make(chan struct{}, 1)
   429  
   430  	SetNoticeWriter(NewNoticeReceiver(
   431  		func(notice []byte) {
   432  
   433  			noticeType, payload, err := GetNotice(notice)
   434  			if err != nil {
   435  				return
   436  			}
   437  
   438  			printNotice := false
   439  
   440  			switch noticeType {
   441  			case "Tunnels":
   442  				printNotice = true
   443  				count := int(payload["count"].(float64))
   444  				if count == 1 {
   445  					tunnelEstablished <- struct{}{}
   446  				}
   447  			case "RemoteServerListResourceDownloadedBytes":
   448  				// TODO: check for resumed download for each URL
   449  				//url := payload["url"].(string)
   450  				//printNotice = true
   451  				printNotice = false
   452  			case "RemoteServerListResourceDownloaded":
   453  				printNotice = true
   454  			}
   455  
   456  			if printNotice {
   457  				fmt.Printf("%s\n", string(notice))
   458  			}
   459  		}))
   460  
   461  	ctx, cancelFunc := context.WithCancel(context.Background())
   462  	defer cancelFunc()
   463  
   464  	go func() {
   465  		controller.Run(ctx)
   466  	}()
   467  
   468  	establishTimeout := time.NewTimer(30 * time.Second)
   469  	select {
   470  	case <-tunnelEstablished:
   471  	case <-establishTimeout.C:
   472  		t.Fatalf("tunnel establish timeout exceeded")
   473  	}
   474  
   475  	for _, paveFile := range paveFiles {
   476  		u, _ := url.Parse(obfuscatedServerListRootURLs[0])
   477  		u.Path = path.Join(u.Path, paveFile.Name)
   478  		etag, _ := GetUrlETag(u.String())
   479  		md5sum := md5.Sum(paveFile.Contents)
   480  		if etag != fmt.Sprintf("\"%s\"", hex.EncodeToString(md5sum[:])) {
   481  			t.Fatalf("unexpected ETag for %s", u)
   482  		}
   483  	}
   484  }