gitee.com/ks-custle/core-gm@v0.0.0-20230922171213-b83bdd97b62c/net/websocket/websocket_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "fmt" 11 "io" 12 "log" 13 "net" 14 "net/url" 15 "reflect" 16 "runtime" 17 "strings" 18 "sync" 19 "testing" 20 "time" 21 22 "gitee.com/ks-custle/core-gm/gmhttp/httptest" 23 24 http "gitee.com/ks-custle/core-gm/gmhttp" 25 ) 26 27 var serverAddr string 28 var once sync.Once 29 30 func echoServer(ws *Conn) { 31 defer ws.Close() 32 io.Copy(ws, ws) 33 } 34 35 type Count struct { 36 S string 37 N int 38 } 39 40 func countServer(ws *Conn) { 41 defer ws.Close() 42 for { 43 var count Count 44 err := JSON.Receive(ws, &count) 45 if err != nil { 46 return 47 } 48 count.N++ 49 count.S = strings.Repeat(count.S, count.N) 50 err = JSON.Send(ws, count) 51 if err != nil { 52 return 53 } 54 } 55 } 56 57 type testCtrlAndDataHandler struct { 58 hybiFrameHandler 59 } 60 61 func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) { 62 h.hybiFrameHandler.conn.wio.Lock() 63 defer h.hybiFrameHandler.conn.wio.Unlock() 64 w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame) 65 if err != nil { 66 return 0, err 67 } 68 n, err := w.Write(b) 69 w.Close() 70 return n, err 71 } 72 73 func ctrlAndDataServer(ws *Conn) { 74 defer ws.Close() 75 h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}} 76 ws.frameHandler = h 77 78 go func() { 79 for i := 0; ; i++ { 80 var b []byte 81 if i%2 != 0 { // with or without payload 82 b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i)) 83 } 84 if _, err := h.WritePing(b); err != nil { 85 break 86 } 87 if _, err := h.WritePong(b); err != nil { // unsolicited pong 88 break 89 } 90 time.Sleep(10 * time.Millisecond) 91 } 92 }() 93 94 b := make([]byte, 128) 95 for { 96 n, err := ws.Read(b) 97 if err != nil { 98 break 99 } 100 if _, err := ws.Write(b[:n]); err != nil { 101 break 102 } 103 } 104 } 105 106 func subProtocolHandshake(config *Config, req *http.Request) error { 107 for _, proto := range config.Protocol { 108 if proto == "chat" { 109 config.Protocol = []string{proto} 110 return nil 111 } 112 } 113 return ErrBadWebSocketProtocol 114 } 115 116 func subProtoServer(ws *Conn) { 117 for _, proto := range ws.Config().Protocol { 118 io.WriteString(ws, proto) 119 } 120 } 121 122 func startServer() { 123 http.Handle("/echo", Handler(echoServer)) 124 http.Handle("/count", Handler(countServer)) 125 http.Handle("/ctrldata", Handler(ctrlAndDataServer)) 126 subproto := Server{ 127 Handshake: subProtocolHandshake, 128 Handler: Handler(subProtoServer), 129 } 130 http.Handle("/subproto", subproto) 131 server := httptest.NewServer(nil) 132 serverAddr = server.Listener.Addr().String() 133 log.Print("Test WebSocket server listening on ", serverAddr) 134 } 135 136 func newConfig(t *testing.T, path string) *Config { 137 config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost") 138 return config 139 } 140 141 func TestEcho(t *testing.T) { 142 once.Do(startServer) 143 144 // websocket.Dial() 145 client, err := net.Dial("tcp", serverAddr) 146 if err != nil { 147 t.Fatal("dialing", err) 148 } 149 conn, err := NewClient(newConfig(t, "/echo"), client) 150 if err != nil { 151 t.Errorf("WebSocket handshake error: %v", err) 152 return 153 } 154 155 msg := []byte("hello, world\n") 156 if _, err := conn.Write(msg); err != nil { 157 t.Errorf("Write: %v", err) 158 } 159 var actual_msg = make([]byte, 512) 160 n, err := conn.Read(actual_msg) 161 if err != nil { 162 t.Errorf("Read: %v", err) 163 } 164 actual_msg = actual_msg[0:n] 165 if !bytes.Equal(msg, actual_msg) { 166 t.Errorf("Echo: expected %q got %q", msg, actual_msg) 167 } 168 conn.Close() 169 } 170 171 func TestAddr(t *testing.T) { 172 once.Do(startServer) 173 174 // websocket.Dial() 175 client, err := net.Dial("tcp", serverAddr) 176 if err != nil { 177 t.Fatal("dialing", err) 178 } 179 conn, err := NewClient(newConfig(t, "/echo"), client) 180 if err != nil { 181 t.Errorf("WebSocket handshake error: %v", err) 182 return 183 } 184 185 ra := conn.RemoteAddr().String() 186 if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") { 187 t.Errorf("Bad remote addr: %v", ra) 188 } 189 la := conn.LocalAddr().String() 190 if !strings.HasPrefix(la, "http://") { 191 t.Errorf("Bad local addr: %v", la) 192 } 193 conn.Close() 194 } 195 196 func TestCount(t *testing.T) { 197 once.Do(startServer) 198 199 // websocket.Dial() 200 client, err := net.Dial("tcp", serverAddr) 201 if err != nil { 202 t.Fatal("dialing", err) 203 } 204 conn, err := NewClient(newConfig(t, "/count"), client) 205 if err != nil { 206 t.Errorf("WebSocket handshake error: %v", err) 207 return 208 } 209 210 var count Count 211 count.S = "hello" 212 if err := JSON.Send(conn, count); err != nil { 213 t.Errorf("Write: %v", err) 214 } 215 if err := JSON.Receive(conn, &count); err != nil { 216 t.Errorf("Read: %v", err) 217 } 218 if count.N != 1 { 219 t.Errorf("count: expected %d got %d", 1, count.N) 220 } 221 if count.S != "hello" { 222 t.Errorf("count: expected %q got %q", "hello", count.S) 223 } 224 if err := JSON.Send(conn, count); err != nil { 225 t.Errorf("Write: %v", err) 226 } 227 if err := JSON.Receive(conn, &count); err != nil { 228 t.Errorf("Read: %v", err) 229 } 230 if count.N != 2 { 231 t.Errorf("count: expected %d got %d", 2, count.N) 232 } 233 if count.S != "hellohello" { 234 t.Errorf("count: expected %q got %q", "hellohello", count.S) 235 } 236 conn.Close() 237 } 238 239 func TestWithQuery(t *testing.T) { 240 once.Do(startServer) 241 242 client, err := net.Dial("tcp", serverAddr) 243 if err != nil { 244 t.Fatal("dialing", err) 245 } 246 247 config := newConfig(t, "/echo") 248 config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr)) 249 if err != nil { 250 t.Fatal("location url", err) 251 } 252 253 ws, err := NewClient(config, client) 254 if err != nil { 255 t.Errorf("WebSocket handshake: %v", err) 256 return 257 } 258 ws.Close() 259 } 260 261 func testWithProtocol(t *testing.T, subproto []string) (string, error) { 262 once.Do(startServer) 263 264 client, err := net.Dial("tcp", serverAddr) 265 if err != nil { 266 t.Fatal("dialing", err) 267 } 268 269 config := newConfig(t, "/subproto") 270 config.Protocol = subproto 271 272 ws, err := NewClient(config, client) 273 if err != nil { 274 return "", err 275 } 276 msg := make([]byte, 16) 277 n, err := ws.Read(msg) 278 if err != nil { 279 return "", err 280 } 281 ws.Close() 282 return string(msg[:n]), nil 283 } 284 285 func TestWithProtocol(t *testing.T) { 286 proto, err := testWithProtocol(t, []string{"chat"}) 287 if err != nil { 288 t.Errorf("SubProto: unexpected error: %v", err) 289 } 290 if proto != "chat" { 291 t.Errorf("SubProto: expected %q, got %q", "chat", proto) 292 } 293 } 294 295 func TestWithTwoProtocol(t *testing.T) { 296 proto, err := testWithProtocol(t, []string{"test", "chat"}) 297 if err != nil { 298 t.Errorf("SubProto: unexpected error: %v", err) 299 } 300 if proto != "chat" { 301 t.Errorf("SubProto: expected %q, got %q", "chat", proto) 302 } 303 } 304 305 func TestWithBadProtocol(t *testing.T) { 306 _, err := testWithProtocol(t, []string{"test"}) 307 if err != ErrBadStatus { 308 t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err) 309 } 310 } 311 312 func TestHTTP(t *testing.T) { 313 once.Do(startServer) 314 315 // If the client did not send a handshake that matches the protocol 316 // specification, the server MUST return an HTTP response with an 317 // appropriate error code (such as 400 Bad Request) 318 resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr)) 319 if err != nil { 320 t.Errorf("Get: error %#v", err) 321 return 322 } 323 if resp == nil { 324 t.Error("Get: resp is null") 325 return 326 } 327 if resp.StatusCode != http.StatusBadRequest { 328 t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode) 329 } 330 } 331 332 func TestTrailingSpaces(t *testing.T) { 333 // http://code.google.com/p/go/issues/detail?id=955 334 // The last runs of this create keys with trailing spaces that should not be 335 // generated by the client. 336 once.Do(startServer) 337 config := newConfig(t, "/echo") 338 for i := 0; i < 30; i++ { 339 // body 340 ws, err := DialConfig(config) 341 if err != nil { 342 t.Errorf("Dial #%d failed: %v", i, err) 343 break 344 } 345 ws.Close() 346 } 347 } 348 349 func TestDialConfigBadVersion(t *testing.T) { 350 once.Do(startServer) 351 config := newConfig(t, "/echo") 352 config.Version = 1234 353 354 _, err := DialConfig(config) 355 356 if dialerr, ok := err.(*DialError); ok { 357 if dialerr.Err != ErrBadProtocolVersion { 358 t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err) 359 } 360 } 361 } 362 363 func TestDialConfigWithDialer(t *testing.T) { 364 once.Do(startServer) 365 config := newConfig(t, "/echo") 366 config.Dialer = &net.Dialer{ 367 Deadline: time.Now().Add(-time.Minute), 368 } 369 _, err := DialConfig(config) 370 dialerr, ok := err.(*DialError) 371 if !ok { 372 t.Fatalf("DialError expected, got %#v", err) 373 } 374 neterr, ok := dialerr.Err.(*net.OpError) 375 if !ok { 376 t.Fatalf("net.OpError error expected, got %#v", dialerr.Err) 377 } 378 if !neterr.Timeout() { 379 t.Fatalf("expected timeout error, got %#v", neterr) 380 } 381 } 382 383 func TestSmallBuffer(t *testing.T) { 384 // http://code.google.com/p/go/issues/detail?id=1145 385 // Read should be able to handle reading a fragment of a frame. 386 once.Do(startServer) 387 388 // websocket.Dial() 389 client, err := net.Dial("tcp", serverAddr) 390 if err != nil { 391 t.Fatal("dialing", err) 392 } 393 conn, err := NewClient(newConfig(t, "/echo"), client) 394 if err != nil { 395 t.Errorf("WebSocket handshake error: %v", err) 396 return 397 } 398 399 msg := []byte("hello, world\n") 400 if _, err := conn.Write(msg); err != nil { 401 t.Errorf("Write: %v", err) 402 } 403 var small_msg = make([]byte, 8) 404 n, err := conn.Read(small_msg) 405 if err != nil { 406 t.Errorf("Read: %v", err) 407 } 408 if !bytes.Equal(msg[:len(small_msg)], small_msg) { 409 t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg) 410 } 411 var second_msg = make([]byte, len(msg)) 412 n, err = conn.Read(second_msg) 413 if err != nil { 414 t.Errorf("Read: %v", err) 415 } 416 second_msg = second_msg[0:n] 417 if !bytes.Equal(msg[len(small_msg):], second_msg) { 418 t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg) 419 } 420 conn.Close() 421 } 422 423 var parseAuthorityTests = []struct { 424 in *url.URL 425 out string 426 }{ 427 { 428 &url.URL{ 429 Scheme: "ws", 430 Host: "www.google.com", 431 }, 432 "www.google.com:80", 433 }, 434 { 435 &url.URL{ 436 Scheme: "wss", 437 Host: "www.google.com", 438 }, 439 "www.google.com:443", 440 }, 441 { 442 &url.URL{ 443 Scheme: "ws", 444 Host: "www.google.com:80", 445 }, 446 "www.google.com:80", 447 }, 448 { 449 &url.URL{ 450 Scheme: "wss", 451 Host: "www.google.com:443", 452 }, 453 "www.google.com:443", 454 }, 455 // some invalid ones for parseAuthority. parseAuthority doesn't 456 // concern itself with the scheme unless it actually knows about it 457 { 458 &url.URL{ 459 Scheme: "http", 460 Host: "www.google.com", 461 }, 462 "www.google.com", 463 }, 464 { 465 &url.URL{ 466 Scheme: "http", 467 Host: "www.google.com:80", 468 }, 469 "www.google.com:80", 470 }, 471 { 472 &url.URL{ 473 Scheme: "asdf", 474 Host: "127.0.0.1", 475 }, 476 "127.0.0.1", 477 }, 478 { 479 &url.URL{ 480 Scheme: "asdf", 481 Host: "www.google.com", 482 }, 483 "www.google.com", 484 }, 485 } 486 487 func TestParseAuthority(t *testing.T) { 488 for _, tt := range parseAuthorityTests { 489 out := parseAuthority(tt.in) 490 if out != tt.out { 491 t.Errorf("got %v; want %v", out, tt.out) 492 } 493 } 494 } 495 496 type closerConn struct { 497 net.Conn 498 closed int // count of the number of times Close was called 499 } 500 501 func (c *closerConn) Close() error { 502 c.closed++ 503 return c.Conn.Close() 504 } 505 506 func TestClose(t *testing.T) { 507 if runtime.GOOS == "plan9" { 508 t.Skip("see golang.org/issue/11454") 509 } 510 511 once.Do(startServer) 512 513 conn, err := net.Dial("tcp", serverAddr) 514 if err != nil { 515 t.Fatal("dialing", err) 516 } 517 518 cc := closerConn{Conn: conn} 519 520 client, err := NewClient(newConfig(t, "/echo"), &cc) 521 if err != nil { 522 t.Fatalf("WebSocket handshake: %v", err) 523 } 524 525 // set the deadline to ten minutes ago, which will have expired by the time 526 // client.Close sends the close status frame. 527 conn.SetDeadline(time.Now().Add(-10 * time.Minute)) 528 529 if err := client.Close(); err == nil { 530 t.Errorf("ws.Close(): expected error, got %v", err) 531 } 532 if cc.closed < 1 { 533 t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed) 534 } 535 } 536 537 var originTests = []struct { 538 req *http.Request 539 origin *url.URL 540 }{ 541 { 542 req: &http.Request{ 543 Header: http.Header{ 544 "Origin": []string{"http://www.example.com"}, 545 }, 546 }, 547 origin: &url.URL{ 548 Scheme: "http", 549 Host: "www.example.com", 550 }, 551 }, 552 { 553 req: &http.Request{}, 554 }, 555 } 556 557 func TestOrigin(t *testing.T) { 558 conf := newConfig(t, "/echo") 559 conf.Version = ProtocolVersionHybi13 560 for i, tt := range originTests { 561 origin, err := Origin(conf, tt.req) 562 if err != nil { 563 t.Error(err) 564 continue 565 } 566 if !reflect.DeepEqual(origin, tt.origin) { 567 t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin) 568 continue 569 } 570 } 571 } 572 573 func TestCtrlAndData(t *testing.T) { 574 once.Do(startServer) 575 576 c, err := net.Dial("tcp", serverAddr) 577 if err != nil { 578 t.Fatal(err) 579 } 580 ws, err := NewClient(newConfig(t, "/ctrldata"), c) 581 if err != nil { 582 t.Fatal(err) 583 } 584 defer ws.Close() 585 586 h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}} 587 ws.frameHandler = h 588 589 b := make([]byte, 128) 590 for i := 0; i < 2; i++ { 591 data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i)) 592 if _, err := ws.Write(data); err != nil { 593 t.Fatalf("#%d: %v", i, err) 594 } 595 var ctrl []byte 596 if i%2 != 0 { // with or without payload 597 ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i)) 598 } 599 if _, err := h.WritePing(ctrl); err != nil { 600 t.Fatalf("#%d: %v", i, err) 601 } 602 n, err := ws.Read(b) 603 if err != nil { 604 t.Fatalf("#%d: %v", i, err) 605 } 606 if !bytes.Equal(b[:n], data) { 607 t.Fatalf("#%d: got %v; want %v", i, b[:n], data) 608 } 609 } 610 } 611 612 func TestCodec_ReceiveLimited(t *testing.T) { 613 const limit = 2048 614 var payloads [][]byte 615 for _, size := range []int{ 616 1024, 617 2048, 618 4096, // receive of this message would be interrupted due to limit 619 2048, // this one is to make sure next receive recovers discarding leftovers 620 } { 621 b := make([]byte, size) 622 rand.Read(b) 623 payloads = append(payloads, b) 624 } 625 handlerDone := make(chan struct{}) 626 limitedHandler := func(ws *Conn) { 627 defer close(handlerDone) 628 ws.MaxPayloadBytes = limit 629 defer ws.Close() 630 for i, p := range payloads { 631 t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit) 632 var recv []byte 633 err := Message.Receive(ws, &recv) 634 switch err { 635 case nil: 636 case ErrFrameTooLarge: 637 if len(p) <= limit { 638 t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit) 639 } 640 continue 641 default: 642 t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err) 643 } 644 if len(recv) > limit { 645 t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit) 646 } 647 if !bytes.Equal(p, recv) { 648 t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p) 649 } 650 } 651 } 652 server := httptest.NewServer(Handler(limitedHandler)) 653 defer server.CloseClientConnections() 654 defer server.Close() 655 addr := server.Listener.Addr().String() 656 ws, err := Dial("ws://"+addr+"/", "", "http://localhost/") 657 if err != nil { 658 t.Fatal(err) 659 } 660 defer ws.Close() 661 for i, p := range payloads { 662 if err := Message.Send(ws, p); err != nil { 663 t.Fatalf("payload #%d (size %d): %v", i, len(p), err) 664 } 665 } 666 <-handlerDone 667 }