github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/interrupt_dials_test.go (about) 1 /* 2 * Copyright (c) 2017, Psiphon Inc. 3 * All rights reserved. 4 * 5 * This program is free software: you can redistribute it and/or modify 6 * it under the terms of the GNU General Public License as published by 7 * the Free Software Foundation, either version 3 of the License, or 8 * (at your option) any later version. 9 * 10 * This program is distributed in the hope that it will be useful, 11 * but WITHOUT ANY WARRANTY; without even the implied warranty of 12 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 * GNU General Public License for more details. 14 * 15 * You should have received a copy of the GNU General Public License 16 * along with this program. If not, see <http://www.gnu.org/licenses/>. 17 * 18 */ 19 20 package psiphon 21 22 import ( 23 "context" 24 "fmt" 25 "net" 26 "runtime" 27 "strings" 28 "sync" 29 "testing" 30 "time" 31 32 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common" 33 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/parameters" 34 "github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng" 35 ) 36 37 func TestInterruptDials(t *testing.T) { 38 39 resolveIP := func(_ context.Context, host string) ([]net.IP, error) { 40 return []net.IP{net.ParseIP(host)}, nil 41 } 42 43 makeDialers := make(map[string]func(string) common.Dialer) 44 45 makeDialers["TCP"] = func(string) common.Dialer { 46 return NewTCPDialer(&DialConfig{ResolveIP: resolveIP}) 47 } 48 49 makeDialers["SOCKS4-Proxied"] = func(mockServerAddr string) common.Dialer { 50 return NewTCPDialer( 51 &DialConfig{ 52 ResolveIP: resolveIP, 53 UpstreamProxyURL: "socks4a://" + mockServerAddr, 54 }) 55 } 56 57 makeDialers["SOCKS5-Proxied"] = func(mockServerAddr string) common.Dialer { 58 return NewTCPDialer( 59 &DialConfig{ 60 ResolveIP: resolveIP, 61 UpstreamProxyURL: "socks5://" + mockServerAddr, 62 }) 63 } 64 65 makeDialers["HTTP-CONNECT-Proxied"] = func(mockServerAddr string) common.Dialer { 66 return NewTCPDialer( 67 &DialConfig{ 68 ResolveIP: resolveIP, 69 UpstreamProxyURL: "http://" + mockServerAddr, 70 }) 71 } 72 73 // TODO: test upstreamproxy.ProxyAuthTransport 74 75 params, err := parameters.NewParameters(nil) 76 if err != nil { 77 t.Fatalf("NewParameters failed: %s", err) 78 } 79 80 seed, err := prng.NewSeed() 81 if err != nil { 82 t.Fatalf("NewSeed failed: %s", err) 83 } 84 85 makeDialers["TLS"] = func(string) common.Dialer { 86 return NewCustomTLSDialer( 87 &CustomTLSConfig{ 88 Parameters: params, 89 Dial: NewTCPDialer(&DialConfig{ResolveIP: resolveIP}), 90 RandomizedTLSProfileSeed: seed, 91 }) 92 } 93 94 dialGoroutineFunctionNames := []string{"NewTCPDialer", "NewCustomTLSDialer"} 95 96 for dialerName, makeDialer := range makeDialers { 97 for _, doTimeout := range []bool{true, false} { 98 t.Run( 99 fmt.Sprintf("%s-timeout-%+v", dialerName, doTimeout), 100 func(t *testing.T) { 101 runInterruptDials( 102 t, 103 doTimeout, 104 makeDialer, 105 dialGoroutineFunctionNames) 106 }) 107 } 108 } 109 110 } 111 112 func runInterruptDials( 113 t *testing.T, 114 doTimeout bool, 115 makeDialer func(string) common.Dialer, 116 dialGoroutineFunctionNames []string) { 117 118 t.Logf("Test timeout: %+v", doTimeout) 119 120 noAcceptListener, err := net.Listen("tcp", "127.0.0.1:0") 121 if err != nil { 122 t.Fatalf("Listen failed: %s", err) 123 } 124 defer noAcceptListener.Close() 125 126 noResponseListener, err := net.Listen("tcp", "127.0.0.1:0") 127 if err != nil { 128 t.Fatalf("Listen failed: %s", err) 129 } 130 defer noResponseListener.Close() 131 132 listenerAccepted := make(chan struct{}, 1) 133 134 noResponseListenerWaitGroup := new(sync.WaitGroup) 135 noResponseListenerWaitGroup.Add(1) 136 defer noResponseListenerWaitGroup.Wait() 137 go func() { 138 defer noResponseListenerWaitGroup.Done() 139 for { 140 conn, err := noResponseListener.Accept() 141 if err != nil { 142 return 143 } 144 listenerAccepted <- struct{}{} 145 146 var b [1024]byte 147 for { 148 _, err := conn.Read(b[:]) 149 if err != nil { 150 conn.Close() 151 return 152 } 153 } 154 } 155 }() 156 157 var ctx context.Context 158 var cancelFunc context.CancelFunc 159 160 timeout := 100 * time.Millisecond 161 162 if doTimeout { 163 ctx, cancelFunc = context.WithTimeout(context.Background(), timeout) 164 } else { 165 ctx, cancelFunc = context.WithCancel(context.Background()) 166 } 167 168 addrs := []string{ 169 noAcceptListener.Addr().String(), 170 noResponseListener.Addr().String()} 171 172 dialTerminated := make(chan struct{}, len(addrs)) 173 174 for _, addr := range addrs { 175 go func(addr string) { 176 conn, err := makeDialer(addr)(ctx, "tcp", addr) 177 if err == nil { 178 conn.Close() 179 } 180 dialTerminated <- struct{}{} 181 }(addr) 182 } 183 184 // Wait for noResponseListener to accept to ensure that we exercise 185 // post-TCP-dial interruption in the case of TLS and proxy dialers that 186 // do post-TCP-dial handshake I/O as part of their dial. 187 188 <-listenerAccepted 189 190 if doTimeout { 191 time.Sleep(timeout) 192 defer cancelFunc() 193 } else { 194 // No timeout, so interrupt with cancel 195 cancelFunc() 196 } 197 198 startWaiting := time.Now() 199 200 for range addrs { 201 <-dialTerminated 202 } 203 204 // Test: dial interrupt must complete quickly 205 206 interruptDuration := time.Since(startWaiting) 207 208 if interruptDuration > 100*time.Millisecond { 209 t.Fatalf("interrupt duration too long: %s", interruptDuration) 210 } 211 212 // Test: interrupted dialers must not leave goroutines running 213 214 if findGoroutines(t, dialGoroutineFunctionNames) { 215 t.Fatalf("unexpected dial goroutines") 216 } 217 } 218 219 func findGoroutines(t *testing.T, targets []string) bool { 220 n, _ := runtime.GoroutineProfile(nil) 221 r := make([]runtime.StackRecord, n) 222 runtime.GoroutineProfile(r) 223 found := false 224 for _, g := range r { 225 stack := g.Stack() 226 funcNames := make([]string, len(stack)) 227 for i := 0; i < len(stack); i++ { 228 funcNames[i] = getFunctionName(stack[i]) 229 } 230 s := strings.Join(funcNames, ", ") 231 for _, target := range targets { 232 if strings.Contains(s, target) { 233 t.Logf("found dial goroutine: %s", s) 234 found = true 235 } 236 } 237 } 238 return found 239 } 240 241 func getFunctionName(pc uintptr) string { 242 funcName := runtime.FuncForPC(pc).Name() 243 index := strings.LastIndex(funcName, "/") 244 if index != -1 { 245 funcName = funcName[index+1:] 246 } 247 return funcName 248 }