github.com/psiphon-labs/psiphon-tunnel-core@v2.0.28+incompatible/psiphon/common/quic/quic_test.go (about)

     1  //go:build !PSIPHON_DISABLE_QUIC
     2  // +build !PSIPHON_DISABLE_QUIC
     3  
     4  /*
     5   * Copyright (c) 2018, Psiphon Inc.
     6   * All rights reserved.
     7   *
     8   * This program is free software: you can redistribute it and/or modify
     9   * it under the terms of the GNU General Public License as published by
    10   * the Free Software Foundation, either version 3 of the License, or
    11   * (at your option) any later version.
    12   *
    13   * This program is distributed in the hope that it will be useful,
    14   * but WITHOUT ANY WARRANTY; without even the implied warranty of
    15   * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    16   * GNU General Public License for more details.
    17   *
    18   * You should have received a copy of the GNU General Public License
    19   * along with this program.  If not, see <http://www.gnu.org/licenses/>.
    20   *
    21   */
    22  
    23  package quic
    24  
    25  import (
    26  	"context"
    27  	"fmt"
    28  	"io"
    29  	"net"
    30  	"runtime"
    31  	"strings"
    32  	"sync/atomic"
    33  	"testing"
    34  	"time"
    35  
    36  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common"
    37  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/errors"
    38  	"github.com/Psiphon-Labs/psiphon-tunnel-core/psiphon/common/prng"
    39  	"golang.org/x/sync/errgroup"
    40  )
    41  
    42  func TestQUIC(t *testing.T) {
    43  	for quicVersion := range supportedVersionNumbers {
    44  		t.Run(fmt.Sprintf("%s", quicVersion), func(t *testing.T) {
    45  			if isGQUIC(quicVersion) && !GQUICEnabled() {
    46  				t.Skipf("gQUIC is not enabled")
    47  			}
    48  			runQUIC(t, quicVersion, GQUICEnabled(), false)
    49  		})
    50  		if isIETF(quicVersion) {
    51  			t.Run(fmt.Sprintf("%s (invoke anti-probing)", quicVersion), func(t *testing.T) {
    52  				runQUIC(t, quicVersion, GQUICEnabled(), true)
    53  			})
    54  		}
    55  		if isIETF(quicVersion) {
    56  			t.Run(fmt.Sprintf("%s (disable gQUIC)", quicVersion), func(t *testing.T) {
    57  				runQUIC(t, quicVersion, false, false)
    58  			})
    59  		}
    60  	}
    61  }
    62  
    63  func runQUIC(
    64  	t *testing.T,
    65  	quicVersion string,
    66  	enableGQUIC bool,
    67  	invokeAntiProbing bool) {
    68  
    69  	initGoroutines := getGoroutines()
    70  
    71  	clients := 10
    72  	bytesToSend := 1 << 20
    73  
    74  	serverReceivedBytes := int64(0)
    75  	clientReceivedBytes := int64(0)
    76  
    77  	// Intermittently, on some platforms, the client connection termination
    78  	// packet is not received even when sent/received locally; set a brief
    79  	// idle timeout to ensure the server-side client handler doesn't block too
    80  	// long on Read, causing the test to fail.
    81  	//
    82  	// In realistic network conditions, and especially under adversarial
    83  	// network conditions, we should not expect to regularly receive client
    84  	// connection termination packets.
    85  	serverIdleTimeout = 1 * time.Second
    86  
    87  	irregularTunnelLogger := func(_ string, err error, _ common.LogFields) {
    88  		if !invokeAntiProbing {
    89  			t.Errorf("unexpected irregular tunnel event: %v", err)
    90  		}
    91  	}
    92  
    93  	obfuscationKey := prng.HexString(32)
    94  
    95  	listener, err := Listen(
    96  		nil,
    97  		irregularTunnelLogger,
    98  		"127.0.0.1:0",
    99  		obfuscationKey,
   100  		enableGQUIC)
   101  	if err != nil {
   102  		t.Fatalf("Listen failed: %s", err)
   103  	}
   104  
   105  	serverAddress := listener.Addr().String()
   106  
   107  	testGroup, testCtx := errgroup.WithContext(context.Background())
   108  
   109  	testGroup.Go(func() error {
   110  
   111  		if invokeAntiProbing {
   112  			// The quic-go server can still handshake new sessions even if
   113  			// Accept isn't called.
   114  			return nil
   115  		}
   116  
   117  		var serverGroup errgroup.Group
   118  
   119  		for i := 0; i < clients; i++ {
   120  
   121  			conn, err := listener.Accept()
   122  			if err != nil {
   123  				return errors.Trace(err)
   124  			}
   125  
   126  			serverGroup.Go(func() error {
   127  				b := make([]byte, 1024)
   128  				for {
   129  					n, err := conn.Read(b)
   130  					atomic.AddInt64(&serverReceivedBytes, int64(n))
   131  					if err == io.EOF {
   132  						return nil
   133  					} else if err != nil {
   134  						return errors.Trace(err)
   135  					}
   136  					_, err = conn.Write(b[:n])
   137  					if err != nil {
   138  						return errors.Trace(err)
   139  					}
   140  				}
   141  			})
   142  		}
   143  
   144  		err := serverGroup.Wait()
   145  		if err != nil {
   146  			return errors.Trace(err)
   147  		}
   148  
   149  		return nil
   150  	})
   151  
   152  	for i := 0; i < clients; i++ {
   153  
   154  		disablePathMTUDiscovery := i%2 == 0
   155  
   156  		testGroup.Go(func() error {
   157  
   158  			ctx, cancelFunc := context.WithTimeout(
   159  				context.Background(), 1*time.Second)
   160  			defer cancelFunc()
   161  
   162  			remoteAddr, err := net.ResolveUDPAddr("udp", serverAddress)
   163  			if err != nil {
   164  				return errors.Trace(err)
   165  			}
   166  
   167  			packetConn, err := net.ListenPacket("udp4", "127.0.0.1:0")
   168  			if err != nil {
   169  				return errors.Trace(err)
   170  			}
   171  
   172  			clientObfuscationKey := obfuscationKey
   173  			if invokeAntiProbing {
   174  				clientObfuscationKey = prng.HexString(32)
   175  				packetConn = &countReadsConn{PacketConn: packetConn}
   176  			}
   177  
   178  			obfuscationPaddingSeed, err := prng.NewSeed()
   179  			if err != nil {
   180  				return errors.Trace(err)
   181  			}
   182  
   183  			var clientHelloSeed *prng.Seed
   184  			if isClientHelloRandomized(quicVersion) {
   185  				clientHelloSeed, err = prng.NewSeed()
   186  				if err != nil {
   187  					return errors.Trace(err)
   188  				}
   189  			}
   190  
   191  			conn, err := Dial(
   192  				ctx,
   193  				packetConn,
   194  				remoteAddr,
   195  				serverAddress,
   196  				quicVersion,
   197  				clientHelloSeed,
   198  				clientObfuscationKey,
   199  				obfuscationPaddingSeed,
   200  				disablePathMTUDiscovery)
   201  
   202  			if invokeAntiProbing {
   203  
   204  				if err == nil {
   205  					return errors.TraceNew(
   206  						"unexpected dial success with invalid client hello random")
   207  				}
   208  
   209  				readCount := packetConn.(*countReadsConn).getReadCount()
   210  
   211  				if readCount > 0 {
   212  					return errors.Tracef(
   213  						"unexpected %d read packets with invalid client hello random",
   214  						readCount)
   215  				}
   216  
   217  				return nil
   218  			}
   219  
   220  			if err != nil {
   221  				return errors.Trace(err)
   222  			}
   223  
   224  			// Cancel should interrupt dialing only
   225  			cancelFunc()
   226  
   227  			var clientGroup errgroup.Group
   228  
   229  			clientGroup.Go(func() error {
   230  				defer conn.Close()
   231  				b := make([]byte, 1024)
   232  				bytesRead := 0
   233  				for bytesRead < bytesToSend {
   234  					n, err := conn.Read(b)
   235  					bytesRead += n
   236  					atomic.AddInt64(&clientReceivedBytes, int64(n))
   237  					if err == io.EOF {
   238  						break
   239  					} else if err != nil {
   240  						return errors.Trace(err)
   241  					}
   242  				}
   243  				return nil
   244  			})
   245  
   246  			clientGroup.Go(func() error {
   247  				b := make([]byte, bytesToSend)
   248  				_, err := conn.Write(b)
   249  				if err != nil {
   250  					return errors.Trace(err)
   251  				}
   252  				return nil
   253  			})
   254  
   255  			return clientGroup.Wait()
   256  		})
   257  
   258  	}
   259  
   260  	go func() {
   261  		testGroup.Wait()
   262  	}()
   263  
   264  	<-testCtx.Done()
   265  	listener.Close()
   266  
   267  	err = testGroup.Wait()
   268  	if err != nil {
   269  		t.Errorf("goroutine failed: %s", err)
   270  	}
   271  
   272  	bytes := atomic.LoadInt64(&serverReceivedBytes)
   273  	expectedBytes := int64(clients * bytesToSend)
   274  	if invokeAntiProbing {
   275  		expectedBytes = 0
   276  	}
   277  	if bytes != expectedBytes {
   278  		t.Errorf("unexpected serverReceivedBytes: %d vs. %d", bytes, expectedBytes)
   279  	}
   280  
   281  	bytes = atomic.LoadInt64(&clientReceivedBytes)
   282  	if bytes != expectedBytes {
   283  		t.Errorf("unexpected clientReceivedBytes: %d vs. %d", bytes, expectedBytes)
   284  	}
   285  
   286  	_, err = listener.Accept()
   287  	if err == nil {
   288  		t.Error("unexpected Accept after Close")
   289  	}
   290  
   291  	// Check for unexpected dangling goroutines after shutdown.
   292  	//
   293  	// quic-go.packetHandlerMap.listen shutdown is async and some quic-go
   294  	// goroutines and/or timers dangle so this test makes allowances for these
   295  	// known dangling goroutinees.
   296  
   297  	expectedDanglingGoroutines := []string{
   298  		"quic-go.(*packetHandlerMap).Retire.func1",
   299  		"quic-go.(*packetHandlerMap).ReplaceWithClosed.func1",
   300  		"quic-go.(*packetHandlerMap).RetireResetToken.func1",
   301  		"gquic-go.(*packetHandlerMap).removeByConnectionIDAsString.func1",
   302  	}
   303  
   304  	sleepTime := 100 * time.Millisecond
   305  
   306  	// The longest expected dangling goroutine is in gquic-go and is launched by a timer
   307  	// that fires after ClosedSessionDeleteTimeout, which is 1m. Allow one extra second
   308  	// to ensure this period elapses and the time.AfterFunc runs.
   309  	//
   310  	// To avoid taking 1m to run this test every time, the dangling goroutine check exits
   311  	// early once no dangling goroutines are found. Note that this doesn't account for
   312  	// any timers still pending at the early exit time.
   313  	n := int((61 * time.Second) / sleepTime)
   314  
   315  	for i := 0; i < n; i++ {
   316  
   317  		// Sleep before making any checks, since quic-go.packetHandlerMap.listen
   318  		// shutdown is asynchronous.
   319  		time.Sleep(100 * time.Millisecond)
   320  
   321  		// After the full 61s, no dangling goroutines are expected.
   322  		if i == n-1 {
   323  			expectedDanglingGoroutines = []string{}
   324  		}
   325  
   326  		hasDangling, onlyExpectedDangling := checkDanglingGoroutines(
   327  			t, initGoroutines, expectedDanglingGoroutines)
   328  		if !hasDangling {
   329  			break
   330  		} else if !onlyExpectedDangling {
   331  			t.Fatalf("unexpected dangling goroutines")
   332  		}
   333  	}
   334  }
   335  
   336  type countReadsConn struct {
   337  	net.PacketConn
   338  	readCount int32
   339  }
   340  
   341  func (conn *countReadsConn) ReadFrom(p []byte) (int, net.Addr, error) {
   342  	n, addr, err := conn.PacketConn.ReadFrom(p)
   343  	if n > 0 {
   344  		atomic.AddInt32(&conn.readCount, 1)
   345  	}
   346  	return n, addr, err
   347  }
   348  
   349  func (conn *countReadsConn) getReadCount() int {
   350  	return int(atomic.LoadInt32(&conn.readCount))
   351  }
   352  
   353  func getGoroutines() []runtime.StackRecord {
   354  	n, _ := runtime.GoroutineProfile(nil)
   355  	r := make([]runtime.StackRecord, n)
   356  	runtime.GoroutineProfile(r)
   357  	return r
   358  }
   359  
   360  func checkDanglingGoroutines(
   361  	t *testing.T,
   362  	initGoroutines []runtime.StackRecord,
   363  	expectedDanglingGoroutines []string) (bool, bool) {
   364  
   365  	hasDangling := false
   366  	onlyExpectedDangling := true
   367  	current := getGoroutines()
   368  	for _, g := range current {
   369  		found := false
   370  		for _, h := range initGoroutines {
   371  			if g == h {
   372  				found = true
   373  				break
   374  			}
   375  		}
   376  		if !found {
   377  			stack := g.Stack()
   378  			funcNames := make([]string, len(stack))
   379  			skip := false
   380  			isExpected := false
   381  			for i := 0; i < len(stack); i++ {
   382  				funcNames[i] = getFunctionName(stack[i])
   383  
   384  				// The current goroutine won't have the same stack as in initGoroutines.
   385  				if strings.Contains(funcNames[i], "checkDanglingGoroutines") {
   386  					skip = true
   387  					break
   388  				}
   389  
   390  				// testing.T.Run runs the the test function, f, in another goroutine. f is
   391  				// the current goroutine, which captures initGoroutines.
   392  				// https://github.com/golang/go/blob/release-branch.go1.13/src/testing/testing.go#L960-L961:
   393  				//
   394  				//     go tRunner(t, f)
   395  				//     if !<-t.signal {
   396  				//     ...
   397  				//
   398  				// f may capture initGoroutines before or after testing.T.Run advances to
   399  				// the channel receive, so the stack of the testing.T.Run goroutine may or
   400  				// may not match initGoroutines. Skip it.
   401  				if strings.Contains(funcNames[i], "testing.(*T).Run") {
   402  					skip = true
   403  					break
   404  				}
   405  
   406  				// This goroutine, created by Listener.clientRandomHistory,
   407  				// terminates nondeterministically, based on garbage
   408  				// collection. Skip it.
   409  				if strings.Contains(funcNames[i], "go-cache-lru.(*janitor).Run") {
   410  					skip = true
   411  					break
   412  				}
   413  
   414  				for _, expected := range expectedDanglingGoroutines {
   415  					if strings.Contains(funcNames[i], expected) {
   416  						isExpected = true
   417  						break
   418  					}
   419  				}
   420  				if isExpected {
   421  					break
   422  				}
   423  			}
   424  			if !skip {
   425  				hasDangling = true
   426  				if !isExpected {
   427  					onlyExpectedDangling = false
   428  					s := strings.Join(funcNames, " <- ")
   429  					t.Logf("found unexpected dangling goroutine: %s", s)
   430  				}
   431  			}
   432  		}
   433  	}
   434  	return hasDangling, onlyExpectedDangling
   435  }
   436  
   437  func getFunctionName(pc uintptr) string {
   438  	funcName := runtime.FuncForPC(pc).Name()
   439  	index := strings.LastIndex(funcName, "/")
   440  	if index != -1 {
   441  		funcName = funcName[index+1:]
   442  	}
   443  	return funcName
   444  }