golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/netutil/listen_test.go (about)

     1  // Copyright 2013 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package netutil
     6  
     7  import (
     8  	"context"
     9  	"errors"
    10  	"io"
    11  	"net"
    12  	"sync"
    13  	"sync/atomic"
    14  	"testing"
    15  	"time"
    16  )
    17  
    18  func TestLimitListenerOverload(t *testing.T) {
    19  	const (
    20  		max      = 5
    21  		attempts = max * 2
    22  		msg      = "bye\n"
    23  	)
    24  
    25  	l, err := net.Listen("tcp", "127.0.0.1:0")
    26  	if err != nil {
    27  		t.Fatal(err)
    28  	}
    29  	l = LimitListener(l, max)
    30  
    31  	var wg sync.WaitGroup
    32  	wg.Add(1)
    33  	saturated := make(chan struct{})
    34  	go func() {
    35  		defer wg.Done()
    36  
    37  		accepted := 0
    38  		for {
    39  			c, err := l.Accept()
    40  			if err != nil {
    41  				break
    42  			}
    43  			accepted++
    44  			if accepted == max {
    45  				close(saturated)
    46  			}
    47  			io.WriteString(c, msg)
    48  
    49  			// Leave c open until the listener is closed.
    50  			defer c.Close()
    51  		}
    52  		t.Logf("with limit %d, accepted %d simultaneous connections", max, accepted)
    53  		// The listener accounts open connections based on Listener-side Close
    54  		// calls, so even if the client hangs up early (for example, because it
    55  		// was a random dial from another process instead of from this test), we
    56  		// should not end up accepting more connections than expected.
    57  		if accepted != max {
    58  			t.Errorf("want exactly %d", max)
    59  		}
    60  	}()
    61  
    62  	dialCtx, cancelDial := context.WithCancel(context.Background())
    63  	defer cancelDial()
    64  	dialer := &net.Dialer{}
    65  
    66  	var dialed, served int32
    67  	var pendingDials sync.WaitGroup
    68  	for n := attempts; n > 0; n-- {
    69  		wg.Add(1)
    70  		pendingDials.Add(1)
    71  		go func() {
    72  			defer wg.Done()
    73  
    74  			c, err := dialer.DialContext(dialCtx, l.Addr().Network(), l.Addr().String())
    75  			pendingDials.Done()
    76  			if err != nil {
    77  				t.Log(err)
    78  				return
    79  			}
    80  			atomic.AddInt32(&dialed, 1)
    81  			defer c.Close()
    82  
    83  			// The kernel may queue more than max connections (allowing their dials to
    84  			// succeed), but only max of them should actually be accepted by the
    85  			// server. We can distinguish the two based on whether the listener writes
    86  			// anything to the connection — a connection that was queued but not
    87  			// accepted will be closed without transferring any data.
    88  			if b, err := io.ReadAll(c); len(b) < len(msg) {
    89  				t.Log(err)
    90  				return
    91  			}
    92  			atomic.AddInt32(&served, 1)
    93  		}()
    94  	}
    95  
    96  	// Give the server a bit of time after it saturates to make sure it doesn't
    97  	// exceed its limit after serving this connection, then cancel the remaining
    98  	// dials (if any).
    99  	<-saturated
   100  	time.Sleep(10 * time.Millisecond)
   101  	cancelDial()
   102  	// Wait for the dials to complete to ensure that the port isn't reused before
   103  	// the dials are actually attempted.
   104  	pendingDials.Wait()
   105  	l.Close()
   106  	wg.Wait()
   107  
   108  	t.Logf("served %d simultaneous connections (of %d dialed, %d attempted)", served, dialed, attempts)
   109  
   110  	// If some other process (such as a port scan or another test) happens to dial
   111  	// the listener at the same time, the listener could end up burning its quota
   112  	// on that, resulting in fewer than max test connections being served.
   113  	// But the number served certainly cannot be greater.
   114  	if served > max {
   115  		t.Errorf("expected at most %d served", max)
   116  	}
   117  }
   118  
   119  func TestLimitListenerSaturation(t *testing.T) {
   120  	const (
   121  		max             = 5
   122  		attemptsPerWave = max * 2
   123  		waves           = 10
   124  		msg             = "bye\n"
   125  	)
   126  
   127  	l, err := net.Listen("tcp", "127.0.0.1:0")
   128  	if err != nil {
   129  		t.Fatal(err)
   130  	}
   131  	l = LimitListener(l, max)
   132  
   133  	acceptDone := make(chan struct{})
   134  	defer func() {
   135  		l.Close()
   136  		<-acceptDone
   137  	}()
   138  	go func() {
   139  		defer close(acceptDone)
   140  
   141  		var open, peakOpen int32
   142  		var (
   143  			saturated     = make(chan struct{})
   144  			saturatedOnce sync.Once
   145  		)
   146  		var wg sync.WaitGroup
   147  		for {
   148  			c, err := l.Accept()
   149  			if err != nil {
   150  				break
   151  			}
   152  			if n := atomic.AddInt32(&open, 1); n > peakOpen {
   153  				peakOpen = n
   154  				if n == max {
   155  					saturatedOnce.Do(func() {
   156  						// Wait a bit to make sure the listener doesn't exceed its limit
   157  						// after accepting this connection, then allow the in-flight
   158  						// connections to write out and close.
   159  						time.AfterFunc(10*time.Millisecond, func() { close(saturated) })
   160  					})
   161  				}
   162  			}
   163  			wg.Add(1)
   164  			go func() {
   165  				<-saturated
   166  				io.WriteString(c, msg)
   167  				atomic.AddInt32(&open, -1)
   168  				c.Close()
   169  				wg.Done()
   170  			}()
   171  		}
   172  		wg.Wait()
   173  
   174  		t.Logf("with limit %d, accepted a peak of %d simultaneous connections", max, peakOpen)
   175  		if peakOpen > max {
   176  			t.Errorf("want at most %d", max)
   177  		}
   178  	}()
   179  
   180  	for wave := 0; wave < waves; wave++ {
   181  		var dialed, served int32
   182  		var wg sync.WaitGroup
   183  		for n := attemptsPerWave; n > 0; n-- {
   184  			wg.Add(1)
   185  			go func() {
   186  				defer wg.Done()
   187  
   188  				c, err := net.Dial(l.Addr().Network(), l.Addr().String())
   189  				if err != nil {
   190  					t.Log(err)
   191  					return
   192  				}
   193  				atomic.AddInt32(&dialed, 1)
   194  				defer c.Close()
   195  
   196  				if b, err := io.ReadAll(c); len(b) < len(msg) {
   197  					t.Log(err)
   198  					return
   199  				}
   200  				atomic.AddInt32(&served, 1)
   201  			}()
   202  		}
   203  		wg.Wait()
   204  
   205  		t.Logf("served %d connections (of %d dialed, %d attempted)", served, dialed, attemptsPerWave)
   206  
   207  		// Depending on the kernel's queueing behavior, we could get unlucky
   208  		// and drop one or more connections. However, we should certainly
   209  		// be able to serve at least max attempts out of each wave.
   210  		// (In the typical case, the kernel will queue all of the connections
   211  		// and they will all be served successfully.)
   212  		if dialed < max {
   213  			t.Errorf("expected at least %d dialed", max)
   214  		}
   215  		if served < dialed {
   216  			t.Errorf("expected all dialed connections to be served")
   217  		}
   218  	}
   219  }
   220  
   221  type errorListener struct {
   222  	net.Listener
   223  }
   224  
   225  func (errorListener) Accept() (net.Conn, error) {
   226  	return nil, errFake
   227  }
   228  
   229  var errFake = errors.New("fake error from errorListener")
   230  
   231  // This used to hang.
   232  func TestLimitListenerError(t *testing.T) {
   233  	const n = 2
   234  	ll := LimitListener(errorListener{}, n)
   235  	for i := 0; i < n+1; i++ {
   236  		_, err := ll.Accept()
   237  		if err != errFake {
   238  			t.Fatalf("Accept error = %v; want errFake", err)
   239  		}
   240  	}
   241  }
   242  
   243  func TestLimitListenerClose(t *testing.T) {
   244  	ln, err := net.Listen("tcp", "127.0.0.1:0")
   245  	if err != nil {
   246  		t.Fatal(err)
   247  	}
   248  	defer ln.Close()
   249  	ln = LimitListener(ln, 1)
   250  
   251  	errCh := make(chan error)
   252  	go func() {
   253  		defer close(errCh)
   254  		c, err := net.Dial(ln.Addr().Network(), ln.Addr().String())
   255  		if err != nil {
   256  			errCh <- err
   257  			return
   258  		}
   259  		c.Close()
   260  	}()
   261  
   262  	c, err := ln.Accept()
   263  	if err != nil {
   264  		t.Fatal(err)
   265  	}
   266  	defer c.Close()
   267  
   268  	err = <-errCh
   269  	if err != nil {
   270  		t.Fatalf("Dial: %v", err)
   271  	}
   272  
   273  	// Allow the subsequent Accept to block before closing the listener.
   274  	// (Accept should unblock and return.)
   275  	timer := time.AfterFunc(10*time.Millisecond, func() { ln.Close() })
   276  
   277  	c, err = ln.Accept()
   278  	if err == nil {
   279  		c.Close()
   280  		t.Errorf("Unexpected successful Accept()")
   281  	}
   282  	if timer.Stop() {
   283  		t.Errorf("Accept returned before listener closed: %v", err)
   284  	}
   285  }