github.com/ezoic/ws@v1.0.4-0.20220713205711-5c1d69e074c5/util_test.go (about) 1 package ws 2 3 import ( 4 "bufio" 5 "bytes" 6 "context" 7 "fmt" 8 "io" 9 "math/rand" 10 "net" 11 "net/http" 12 "net/textproto" 13 "reflect" 14 "strings" 15 "sync" 16 "testing" 17 "time" 18 ) 19 20 var readLineCases = []struct { 21 label string 22 in string 23 line []byte 24 err error 25 bufSize int 26 }{ 27 { 28 label: "simple", 29 in: "hello, world!", 30 line: []byte("hello, world!"), 31 err: io.EOF, 32 bufSize: 1024, 33 }, 34 { 35 label: "simple", 36 in: "hello, world!\r\n", 37 line: []byte("hello, world!"), 38 bufSize: 1024, 39 }, 40 { 41 label: "simple", 42 in: "hello, world!\n", 43 line: []byte("hello, world!"), 44 bufSize: 1024, 45 }, 46 { 47 // The case where "\r\n" straddles the buffer. 48 label: "straddle", 49 in: "hello, world!!!\r\n...", 50 line: []byte("hello, world!!!"), 51 bufSize: 16, 52 }, 53 { 54 label: "chunked", 55 in: "hello, world! this is a long long line!", 56 line: []byte("hello, world! this is a long long line!"), 57 err: io.EOF, 58 bufSize: 16, 59 }, 60 { 61 label: "chunked", 62 in: "hello, world! this is a long long line!\r\n", 63 line: []byte("hello, world! this is a long long line!"), 64 bufSize: 16, 65 }, 66 } 67 68 func TestReadLine(t *testing.T) { 69 for _, test := range readLineCases { 70 t.Run(test.label, func(t *testing.T) { 71 br := bufio.NewReaderSize(strings.NewReader(test.in), test.bufSize) 72 bts, err := readLine(br) 73 if err != test.err { 74 t.Errorf("unexpected error: %v; want %v", err, test.err) 75 } 76 if act, exp := bts, test.line; !bytes.Equal(act, exp) { 77 t.Errorf("readLine() result is %#q; want %#q", act, exp) 78 } 79 }) 80 } 81 } 82 83 func BenchmarkReadLine(b *testing.B) { 84 for _, test := range readLineCases { 85 sr := strings.NewReader(test.in) 86 br := bufio.NewReaderSize(sr, test.bufSize) 87 b.Run(test.label, func(b *testing.B) { 88 for i := 0; i < b.N; i++ { 89 _, _ = readLine(br) 90 sr.Reset(test.in) 91 br.Reset(sr) 92 } 93 }) 94 } 95 } 96 97 func TestUpgradeSlowClient(t *testing.T) { 98 for _, test := range []struct { 99 lim *limitWriter 100 }{ 101 { 102 lim: &limitWriter{ 103 Bandwidth: 100, 104 Period: time.Second, 105 Burst: 10, 106 }, 107 }, 108 { 109 lim: &limitWriter{ 110 Bandwidth: 100, 111 Period: time.Second, 112 Burst: 100, 113 }, 114 }, 115 } { 116 t.Run("", func(t *testing.T) { 117 client, server, err := socketPair() 118 if err != nil { 119 t.Fatal(err) 120 } 121 test.lim.Dest = server 122 123 header := http.Header{ 124 "X-Websocket-Test-1": []string{"Yes"}, 125 "X-Websocket-Test-2": []string{"Yes"}, 126 "X-Websocket-Test-3": []string{"Yes"}, 127 "X-Websocket-Test-4": []string{"Yes"}, 128 } 129 d := Dialer{ 130 NetDial: func(ctx context.Context, network, addr string) (net.Conn, error) { 131 return connWithWriter{server, test.lim}, nil 132 }, 133 Header: HandshakeHeaderHTTP(header), 134 } 135 var ( 136 expHost = "example.org" 137 expURI = "/path/to/ws" 138 ) 139 receivedHeader := http.Header{} 140 u := Upgrader{ 141 OnRequest: func(uri []byte) error { 142 if u := string(uri); u != expURI { 143 t.Errorf( 144 "unexpected URI in OnRequest() callback: %q; want %q", 145 u, expURI, 146 ) 147 } 148 return nil 149 }, 150 OnHost: func(host []byte) error { 151 if h := string(host); h != expHost { 152 t.Errorf( 153 "unexpected host in OnRequest() callback: %q; want %q", 154 h, expHost, 155 ) 156 } 157 return nil 158 }, 159 OnHeader: func(key, value []byte) error { 160 receivedHeader.Add(string(key), string(value)) 161 return nil 162 }, 163 } 164 upgrade := make(chan error, 1) 165 go func() { 166 _, err := u.Upgrade(client) 167 upgrade <- err 168 }() 169 170 _, _, _, err = d.Dial(context.Background(), "ws://"+expHost+expURI) 171 if err != nil { 172 t.Errorf("Dial() error: %v", err) 173 } 174 175 if err := <-upgrade; err != nil { 176 t.Errorf("Upgrade() error: %v", err) 177 } 178 for key, values := range header { 179 act, has := receivedHeader[key] 180 if !has { 181 t.Errorf("OnHeader() was not called with %q header key", key) 182 } 183 if !reflect.DeepEqual(act, values) { 184 t.Errorf("OnHeader(%q) different values: %v; want %v", key, act, values) 185 } 186 } 187 }) 188 } 189 } 190 191 type connWithWriter struct { 192 net.Conn 193 w io.Writer 194 } 195 196 func (w connWithWriter) Write(p []byte) (int, error) { 197 return w.w.Write(p) 198 } 199 200 type limitWriter struct { 201 Dest io.Writer 202 Bandwidth int 203 Burst int 204 Period time.Duration 205 206 mu sync.Mutex 207 cond sync.Cond 208 once sync.Once 209 done chan struct{} 210 tickets int 211 } 212 213 func (w *limitWriter) init() { 214 w.once.Do(func() { 215 w.cond.L = &w.mu 216 w.done = make(chan struct{}) 217 218 tick := w.Period / time.Duration(w.Bandwidth) 219 go func() { 220 t := time.NewTicker(tick) 221 for { 222 select { 223 case <-t.C: 224 w.mu.Lock() 225 w.tickets = w.Burst 226 w.mu.Unlock() 227 w.cond.Signal() 228 case <-w.done: 229 t.Stop() 230 return 231 } 232 } 233 }() 234 }) 235 } 236 237 func (w *limitWriter) allow(n int) (allowed int) { 238 w.init() 239 w.mu.Lock() 240 defer w.mu.Unlock() 241 for w.tickets == 0 { 242 w.cond.Wait() 243 } 244 if w.tickets < 0 { 245 return -1 246 } 247 allowed = min(w.tickets, n) 248 w.tickets -= allowed 249 return allowed 250 } 251 252 func (w *limitWriter) Close() error { 253 w.init() 254 w.mu.Lock() 255 defer w.mu.Unlock() 256 if w.tickets < 0 { 257 return nil 258 } 259 w.tickets = -1 260 close(w.done) 261 w.cond.Broadcast() 262 return nil 263 } 264 265 func (w *limitWriter) Write(p []byte) (n int, err error) { 266 w.init() 267 for n < len(p) { 268 m := w.allow(len(p)) 269 if m < 0 { 270 return 0, io.ErrClosedPipe 271 } 272 if _, err := w.Dest.Write(p[n : n+m]); err != nil { 273 return n, err 274 } 275 n += m 276 } 277 return n, nil 278 } 279 280 func socketPair() (client, server net.Conn, err error) { 281 ln, err := net.Listen("tcp", "localhost:") 282 if err != nil { 283 return nil, nil, err 284 } 285 type connAndError struct { 286 conn net.Conn 287 err error 288 } 289 dial := make(chan connAndError, 1) 290 go func() { 291 conn, err := net.Dial("tcp", ln.Addr().String()) 292 dial <- connAndError{conn, err} 293 }() 294 server, err = ln.Accept() 295 if err != nil { 296 return nil, nil, err 297 } 298 ce := <-dial 299 if err := ce.err; err != nil { 300 return nil, nil, err 301 } 302 return ce.conn, server, nil 303 } 304 305 func TestHasToken(t *testing.T) { 306 for i, test := range []struct { 307 header string 308 token string 309 exp bool 310 }{ 311 {"Keep-Alive, Close, Upgrade", "upgrade", true}, 312 {"Keep-Alive, Close, upgrade, hello", "upgrade", true}, 313 {"Keep-Alive, Close, hello", "upgrade", false}, 314 } { 315 t.Run(fmt.Sprintf("#%d", i), func(t *testing.T) { 316 if has := strHasToken(test.header, test.token); has != test.exp { 317 t.Errorf("hasToken(%q, %q) = %v; want %v", test.header, test.token, has, test.exp) 318 } 319 }) 320 } 321 } 322 323 func BenchmarkHasToken(b *testing.B) { 324 for i, bench := range []struct { 325 header string 326 token string 327 }{ 328 {"Keep-Alive, Close, Upgrade", "upgrade"}, 329 {"Keep-Alive, Close, upgrade, hello", "upgrade"}, 330 {"Keep-Alive, Close, hello", "upgrade"}, 331 } { 332 b.Run(fmt.Sprintf("#%d", i), func(b *testing.B) { 333 for i := 0; i < b.N; i++ { 334 _ = strHasToken(bench.header, bench.token) 335 } 336 }) 337 } 338 } 339 340 type equalFoldCase struct { 341 label string 342 a, b string 343 } 344 345 var equalFoldCases = []equalFoldCase{ 346 {"websocket", "WebSocket", "websocket"}, 347 {"upgrade", "Upgrade", "upgrade"}, 348 randomEqualLetters(512), 349 inequalAt(randomEqualLetters(512), 256), 350 } 351 352 func TestAsciiToInt(t *testing.T) { 353 for _, test := range []struct { 354 bts []byte 355 exp int 356 err bool 357 }{ 358 {[]byte{'0'}, 0, false}, 359 {[]byte{'1'}, 1, false}, 360 {[]byte("42"), 42, false}, 361 {[]byte("420"), 420, false}, 362 {[]byte("010050042"), 10050042, false}, 363 } { 364 t.Run(fmt.Sprintf("%s", string(test.bts)), func(t *testing.T) { 365 act, err := asciiToInt(test.bts) 366 if (test.err && err == nil) || (!test.err && err != nil) { 367 t.Errorf("unexpected error: %v", err) 368 } 369 if act != test.exp { 370 t.Errorf("asciiToInt(%v) = %v; want %v", test.bts, act, test.exp) 371 } 372 }) 373 } 374 } 375 376 func TestBtrim(t *testing.T) { 377 for _, test := range []struct { 378 bts []byte 379 exp []byte 380 }{ 381 {[]byte("abc"), []byte("abc")}, 382 {[]byte(" abc"), []byte("abc")}, 383 {[]byte("abc "), []byte("abc")}, 384 {[]byte(" abc "), []byte("abc")}, 385 } { 386 t.Run(fmt.Sprintf("%s", string(test.bts)), func(t *testing.T) { 387 if act := btrim(test.bts); !bytes.Equal(act, test.exp) { 388 t.Errorf("btrim(%v) = %v; want %v", test.bts, act, test.exp) 389 } 390 }) 391 } 392 } 393 394 func TestBSplit3(t *testing.T) { 395 for _, test := range []struct { 396 bts []byte 397 sep byte 398 exp1 []byte 399 exp2 []byte 400 exp3 []byte 401 }{ 402 {[]byte(""), ' ', []byte{}, nil, nil}, 403 {[]byte("GET / HTTP/1.1"), ' ', []byte("GET"), []byte("/"), []byte("HTTP/1.1")}, 404 } { 405 t.Run(fmt.Sprintf("%s", string(test.bts)), func(t *testing.T) { 406 b1, b2, b3 := bsplit3(test.bts, test.sep) 407 if !bytes.Equal(b1, test.exp1) || !bytes.Equal(b2, test.exp2) || !bytes.Equal(b3, test.exp3) { 408 t.Errorf( 409 "bsplit3(%q) = %q, %q, %q; want %q, %q, %q", 410 string(test.bts), string(b1), string(b2), string(b3), 411 string(test.exp1), string(test.exp2), string(test.exp3), 412 ) 413 } 414 }) 415 } 416 } 417 418 var canonicalHeaderCases = [][]byte{ 419 []byte("foo-"), 420 []byte("-foo"), 421 []byte("-"), 422 []byte("foo----bar"), 423 []byte("foo-bar"), 424 []byte("FoO-BaR"), 425 []byte("Foo-Bar"), 426 []byte("sec-websocket-extensions"), 427 } 428 429 func TestCanonicalizeHeaderKey(t *testing.T) { 430 for _, bts := range canonicalHeaderCases { 431 t.Run(fmt.Sprintf("%s", string(bts)), func(t *testing.T) { 432 act := append([]byte(nil), bts...) 433 canonicalizeHeaderKey(act) 434 435 exp := strToBytes(textproto.CanonicalMIMEHeaderKey(string(bts))) 436 437 if !bytes.Equal(act, exp) { 438 t.Errorf( 439 "canonicalizeHeaderKey(%v) = %v; want %v", 440 string(bts), string(act), string(exp), 441 ) 442 } 443 }) 444 } 445 } 446 447 func BenchmarkCanonicalizeHeaderKey(b *testing.B) { 448 for _, bts := range canonicalHeaderCases { 449 b.Run(fmt.Sprintf("%s", string(bts)), func(b *testing.B) { 450 for i := 0; i < b.N; i++ { 451 canonicalizeHeaderKey(bts) 452 } 453 }) 454 } 455 } 456 457 func randomEqualLetters(n int) (c equalFoldCase) { 458 c.label = fmt.Sprintf("rnd_eq_%d", n) 459 460 a, b := make([]byte, n), make([]byte, n) 461 462 for i := 0; i < n; i++ { 463 c := byte(rand.Intn('Z'-'A'+1) + 'A') // Random character from 'A' to 'Z'. 464 a[i] = c 465 b[i] = c | ('a' - 'A') // Swap fold. 466 } 467 468 c.a = string(a) 469 c.b = string(b) 470 471 return 472 } 473 474 func inequalAt(c equalFoldCase, i int) equalFoldCase { 475 bts := make([]byte, len(c.a)) 476 copy(bts, c.a) 477 for { 478 b := byte(rand.Intn('z'-'a'+1) + 'a') 479 if bts[i] != b { 480 bts[i] = b 481 c.a = string(bts) 482 c.label = fmt.Sprintf("rnd_ineq_%d_%d", len(c.a), i) 483 return c 484 } 485 } 486 }