github.com/Psiphon-Labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/interrupt_dials_test.go (about)

     1  /*
     2   * Copyright (c) 2017, Psiphon Inc.
     3   * All rights reserved.
     4   *
     5   * This program is free software: you can redistribute it and/or modify
     6   * it under the terms of the GNU General Public License as published by
     7   * the Free Software Foundation, either version 3 of the License, or
     8   * (at your option) any later version.
     9   *
    10   * This program is distributed in the hope that it will be useful,
    11   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    12   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    13   * GNU General Public License for more details.
    14   *
    15   * You should have received a copy of the GNU General Public License
    16   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    17   *
    18   */
    19  
    20  package psiphon
    21  
    22  import (
    23  	"context"
    24  	"fmt"
    25  	"net"
    26  	"runtime"
    27  	"strings"
    28  	"sync"
    29  	"testing"
    30  	"time"
    31  
    32  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    33  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters"
    34  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
    35  )
    36  
    37  func TestInterruptDials(t *testing.T) {
    38  
    39  	resolveIP := func(_ context.Context, host string) ([]net.IP, error) {
    40  		return []net.IP{net.ParseIP(host)}, nil
    41  	}
    42  
    43  	makeDialers := make(map[string]func(string) common.Dialer)
    44  
    45  	makeDialers["TCP"] = func(string) common.Dialer {
    46  		return NewTCPDialer(&DialConfig{ResolveIP: resolveIP})
    47  	}
    48  
    49  	makeDialers["SOCKS4-Proxied"] = func(mockServerAddr string) common.Dialer {
    50  		return NewTCPDialer(
    51  			&DialConfig{
    52  				ResolveIP:        resolveIP,
    53  				UpstreamProxyURL: "socks4a://" + mockServerAddr,
    54  			})
    55  	}
    56  
    57  	makeDialers["SOCKS5-Proxied"] = func(mockServerAddr string) common.Dialer {
    58  		return NewTCPDialer(
    59  			&DialConfig{
    60  				ResolveIP:        resolveIP,
    61  				UpstreamProxyURL: "socks5://" + mockServerAddr,
    62  			})
    63  	}
    64  
    65  	makeDialers["HTTP-CONNECT-Proxied"] = func(mockServerAddr string) common.Dialer {
    66  		return NewTCPDialer(
    67  			&DialConfig{
    68  				ResolveIP:        resolveIP,
    69  				UpstreamProxyURL: "http://" + mockServerAddr,
    70  			})
    71  	}
    72  
    73  	// TODO: test upstreamproxy.ProxyAuthTransport
    74  
    75  	params, err := parameters.NewParameters(nil)
    76  	if err != nil {
    77  		t.Fatalf("NewParameters failed: %s", err)
    78  	}
    79  
    80  	seed, err := prng.NewSeed()
    81  	if err != nil {
    82  		t.Fatalf("NewSeed failed: %s", err)
    83  	}
    84  
    85  	makeDialers["TLS"] = func(string) common.Dialer {
    86  		return NewCustomTLSDialer(
    87  			&CustomTLSConfig{
    88  				Parameters:               params,
    89  				Dial:                     NewTCPDialer(&DialConfig{ResolveIP: resolveIP}),
    90  				RandomizedTLSProfileSeed: seed,
    91  			})
    92  	}
    93  
    94  	dialGoroutineFunctionNames := []string{"NewTCPDialer", "NewCustomTLSDialer"}
    95  
    96  	for dialerName, makeDialer := range makeDialers {
    97  		for _, doTimeout := range []bool{true, false} {
    98  			t.Run(
    99  				fmt.Sprintf("%s-timeout-%+v", dialerName, doTimeout),
   100  				func(t *testing.T) {
   101  					runInterruptDials(
   102  						t,
   103  						doTimeout,
   104  						makeDialer,
   105  						dialGoroutineFunctionNames)
   106  				})
   107  		}
   108  	}
   109  
   110  }
   111  
   112  func runInterruptDials(
   113  	t *testing.T,
   114  	doTimeout bool,
   115  	makeDialer func(string) common.Dialer,
   116  	dialGoroutineFunctionNames []string) {
   117  
   118  	t.Logf("Test timeout: %+v", doTimeout)
   119  
   120  	noAcceptListener, err := net.Listen("tcp", "127.0.0.1:0")
   121  	if err != nil {
   122  		t.Fatalf("Listen failed: %s", err)
   123  	}
   124  	defer noAcceptListener.Close()
   125  
   126  	noResponseListener, err := net.Listen("tcp", "127.0.0.1:0")
   127  	if err != nil {
   128  		t.Fatalf("Listen failed: %s", err)
   129  	}
   130  	defer noResponseListener.Close()
   131  
   132  	listenerAccepted := make(chan struct{}, 1)
   133  
   134  	noResponseListenerWaitGroup := new(sync.WaitGroup)
   135  	noResponseListenerWaitGroup.Add(1)
   136  	defer noResponseListenerWaitGroup.Wait()
   137  	go func() {
   138  		defer noResponseListenerWaitGroup.Done()
   139  		for {
   140  			conn, err := noResponseListener.Accept()
   141  			if err != nil {
   142  				return
   143  			}
   144  			listenerAccepted <- struct{}{}
   145  
   146  			var b [1024]byte
   147  			for {
   148  				_, err := conn.Read(b[:])
   149  				if err != nil {
   150  					conn.Close()
   151  					return
   152  				}
   153  			}
   154  		}
   155  	}()
   156  
   157  	var ctx context.Context
   158  	var cancelFunc context.CancelFunc
   159  
   160  	timeout := 100 * time.Millisecond
   161  
   162  	if doTimeout {
   163  		ctx, cancelFunc = context.WithTimeout(context.Background(), timeout)
   164  	} else {
   165  		ctx, cancelFunc = context.WithCancel(context.Background())
   166  	}
   167  
   168  	addrs := []string{
   169  		noAcceptListener.Addr().String(),
   170  		noResponseListener.Addr().String()}
   171  
   172  	dialTerminated := make(chan struct{}, len(addrs))
   173  
   174  	for _, addr := range addrs {
   175  		go func(addr string) {
   176  			conn, err := makeDialer(addr)(ctx, "tcp", addr)
   177  			if err == nil {
   178  				conn.Close()
   179  			}
   180  			dialTerminated <- struct{}{}
   181  		}(addr)
   182  	}
   183  
   184  	// Wait for noResponseListener to accept to ensure that we exercise
   185  	// post-TCP-dial interruption in the case of TLS and proxy dialers that
   186  	// do post-TCP-dial handshake I/O as part of their dial.
   187  
   188  	<-listenerAccepted
   189  
   190  	if doTimeout {
   191  		time.Sleep(timeout)
   192  		defer cancelFunc()
   193  	} else {
   194  		// No timeout, so interrupt with cancel
   195  		cancelFunc()
   196  	}
   197  
   198  	startWaiting := time.Now()
   199  
   200  	for range addrs {
   201  		<-dialTerminated
   202  	}
   203  
   204  	// Test: dial interrupt must complete quickly
   205  
   206  	interruptDuration := time.Since(startWaiting)
   207  
   208  	if interruptDuration > 100*time.Millisecond {
   209  		t.Fatalf("interrupt duration too long: %s", interruptDuration)
   210  	}
   211  
   212  	// Test: interrupted dialers must not leave goroutines running
   213  
   214  	if findGoroutines(t, dialGoroutineFunctionNames) {
   215  		t.Fatalf("unexpected dial goroutines")
   216  	}
   217  }
   218  
   219  func findGoroutines(t *testing.T, targets []string) bool {
   220  	n, _ := runtime.GoroutineProfile(nil)
   221  	r := make([]runtime.StackRecord, n)
   222  	runtime.GoroutineProfile(r)
   223  	found := false
   224  	for _, g := range r {
   225  		stack := g.Stack()
   226  		funcNames := make([]string, len(stack))
   227  		for i := 0; i < len(stack); i++ {
   228  			funcNames[i] = getFunctionName(stack[i])
   229  		}
   230  		s := strings.Join(funcNames, ", ")
   231  		for _, target := range targets {
   232  			if strings.Contains(s, target) {
   233  				t.Logf("found dial goroutine: %s", s)
   234  				found = true
   235  			}
   236  		}
   237  	}
   238  	return found
   239  }
   240  
   241  func getFunctionName(pc uintptr) string {
   242  	funcName := runtime.FuncForPC(pc).Name()
   243  	index := strings.LastIndex(funcName, "/")
   244  	if index != -1 {
   245  		funcName = funcName[index+1:]
   246  	}
   247  	return funcName
   248  }