golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/netutil/listen_test.go (about) 1 // Copyright 2013 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package netutil 6 7 import ( 8 "context" 9 "errors" 10 "io" 11 "net" 12 "sync" 13 "sync/atomic" 14 "testing" 15 "time" 16 ) 17 18 func TestLimitListenerOverload(t *testing.T) { 19 const ( 20 max = 5 21 attempts = max * 2 22 msg = "bye\n" 23 ) 24 25 l, err := net.Listen("tcp", "127.0.0.1:0") 26 if err != nil { 27 t.Fatal(err) 28 } 29 l = LimitListener(l, max) 30 31 var wg sync.WaitGroup 32 wg.Add(1) 33 saturated := make(chan struct{}) 34 go func() { 35 defer wg.Done() 36 37 accepted := 0 38 for { 39 c, err := l.Accept() 40 if err != nil { 41 break 42 } 43 accepted++ 44 if accepted == max { 45 close(saturated) 46 } 47 io.WriteString(c, msg) 48 49 // Leave c open until the listener is closed. 50 defer c.Close() 51 } 52 t.Logf("with limit %d, accepted %d simultaneous connections", max, accepted) 53 // The listener accounts open connections based on Listener-side Close 54 // calls, so even if the client hangs up early (for example, because it 55 // was a random dial from another process instead of from this test), we 56 // should not end up accepting more connections than expected. 57 if accepted != max { 58 t.Errorf("want exactly %d", max) 59 } 60 }() 61 62 dialCtx, cancelDial := context.WithCancel(context.Background()) 63 defer cancelDial() 64 dialer := &net.Dialer{} 65 66 var dialed, served int32 67 var pendingDials sync.WaitGroup 68 for n := attempts; n > 0; n-- { 69 wg.Add(1) 70 pendingDials.Add(1) 71 go func() { 72 defer wg.Done() 73 74 c, err := dialer.DialContext(dialCtx, l.Addr().Network(), l.Addr().String()) 75 pendingDials.Done() 76 if err != nil { 77 t.Log(err) 78 return 79 } 80 atomic.AddInt32(&dialed, 1) 81 defer c.Close() 82 83 // The kernel may queue more than max connections (allowing their dials to 84 // succeed), but only max of them should actually be accepted by the 85 // server. We can distinguish the two based on whether the listener writes 86 // anything to the connection — a connection that was queued but not 87 // accepted will be closed without transferring any data. 88 if b, err := io.ReadAll(c); len(b) < len(msg) { 89 t.Log(err) 90 return 91 } 92 atomic.AddInt32(&served, 1) 93 }() 94 } 95 96 // Give the server a bit of time after it saturates to make sure it doesn't 97 // exceed its limit after serving this connection, then cancel the remaining 98 // dials (if any). 99 <-saturated 100 time.Sleep(10 * time.Millisecond) 101 cancelDial() 102 // Wait for the dials to complete to ensure that the port isn't reused before 103 // the dials are actually attempted. 104 pendingDials.Wait() 105 l.Close() 106 wg.Wait() 107 108 t.Logf("served %d simultaneous connections (of %d dialed, %d attempted)", served, dialed, attempts) 109 110 // If some other process (such as a port scan or another test) happens to dial 111 // the listener at the same time, the listener could end up burning its quota 112 // on that, resulting in fewer than max test connections being served. 113 // But the number served certainly cannot be greater. 114 if served > max { 115 t.Errorf("expected at most %d served", max) 116 } 117 } 118 119 func TestLimitListenerSaturation(t *testing.T) { 120 const ( 121 max = 5 122 attemptsPerWave = max * 2 123 waves = 10 124 msg = "bye\n" 125 ) 126 127 l, err := net.Listen("tcp", "127.0.0.1:0") 128 if err != nil { 129 t.Fatal(err) 130 } 131 l = LimitListener(l, max) 132 133 acceptDone := make(chan struct{}) 134 defer func() { 135 l.Close() 136 <-acceptDone 137 }() 138 go func() { 139 defer close(acceptDone) 140 141 var open, peakOpen int32 142 var ( 143 saturated = make(chan struct{}) 144 saturatedOnce sync.Once 145 ) 146 var wg sync.WaitGroup 147 for { 148 c, err := l.Accept() 149 if err != nil { 150 break 151 } 152 if n := atomic.AddInt32(&open, 1); n > peakOpen { 153 peakOpen = n 154 if n == max { 155 saturatedOnce.Do(func() { 156 // Wait a bit to make sure the listener doesn't exceed its limit 157 // after accepting this connection, then allow the in-flight 158 // connections to write out and close. 159 time.AfterFunc(10*time.Millisecond, func() { close(saturated) }) 160 }) 161 } 162 } 163 wg.Add(1) 164 go func() { 165 <-saturated 166 io.WriteString(c, msg) 167 atomic.AddInt32(&open, -1) 168 c.Close() 169 wg.Done() 170 }() 171 } 172 wg.Wait() 173 174 t.Logf("with limit %d, accepted a peak of %d simultaneous connections", max, peakOpen) 175 if peakOpen > max { 176 t.Errorf("want at most %d", max) 177 } 178 }() 179 180 for wave := 0; wave < waves; wave++ { 181 var dialed, served int32 182 var wg sync.WaitGroup 183 for n := attemptsPerWave; n > 0; n-- { 184 wg.Add(1) 185 go func() { 186 defer wg.Done() 187 188 c, err := net.Dial(l.Addr().Network(), l.Addr().String()) 189 if err != nil { 190 t.Log(err) 191 return 192 } 193 atomic.AddInt32(&dialed, 1) 194 defer c.Close() 195 196 if b, err := io.ReadAll(c); len(b) < len(msg) { 197 t.Log(err) 198 return 199 } 200 atomic.AddInt32(&served, 1) 201 }() 202 } 203 wg.Wait() 204 205 t.Logf("served %d connections (of %d dialed, %d attempted)", served, dialed, attemptsPerWave) 206 207 // Depending on the kernel's queueing behavior, we could get unlucky 208 // and drop one or more connections. However, we should certainly 209 // be able to serve at least max attempts out of each wave. 210 // (In the typical case, the kernel will queue all of the connections 211 // and they will all be served successfully.) 212 if dialed < max { 213 t.Errorf("expected at least %d dialed", max) 214 } 215 if served < dialed { 216 t.Errorf("expected all dialed connections to be served") 217 } 218 } 219 } 220 221 type errorListener struct { 222 net.Listener 223 } 224 225 func (errorListener) Accept() (net.Conn, error) { 226 return nil, errFake 227 } 228 229 var errFake = errors.New("fake error from errorListener") 230 231 // This used to hang. 232 func TestLimitListenerError(t *testing.T) { 233 const n = 2 234 ll := LimitListener(errorListener{}, n) 235 for i := 0; i < n+1; i++ { 236 _, err := ll.Accept() 237 if err != errFake { 238 t.Fatalf("Accept error = %v; want errFake", err) 239 } 240 } 241 } 242 243 func TestLimitListenerClose(t *testing.T) { 244 ln, err := net.Listen("tcp", "127.0.0.1:0") 245 if err != nil { 246 t.Fatal(err) 247 } 248 defer ln.Close() 249 ln = LimitListener(ln, 1) 250 251 errCh := make(chan error) 252 go func() { 253 defer close(errCh) 254 c, err := net.Dial(ln.Addr().Network(), ln.Addr().String()) 255 if err != nil { 256 errCh <- err 257 return 258 } 259 c.Close() 260 }() 261 262 c, err := ln.Accept() 263 if err != nil { 264 t.Fatal(err) 265 } 266 defer c.Close() 267 268 err = <-errCh 269 if err != nil { 270 t.Fatalf("Dial: %v", err) 271 } 272 273 // Allow the subsequent Accept to block before closing the listener. 274 // (Accept should unblock and return.) 275 timer := time.AfterFunc(10*time.Millisecond, func() { ln.Close() }) 276 277 c, err = ln.Accept() 278 if err == nil { 279 c.Close() 280 t.Errorf("Unexpected successful Accept()") 281 } 282 if timer.Stop() { 283 t.Errorf("Accept returned before listener closed: %v", err) 284 } 285 }