github.com/Andyfoo/golang/x/net@v0.0.0-20190901054642-57c1bf301704/websocket/websocket_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 import ( 8 "bytes" 9 "crypto/rand" 10 "fmt" 11 "io" 12 "log" 13 "net" 14 "net/http" 15 "net/http/httptest" 16 "net/url" 17 "reflect" 18 "runtime" 19 "strings" 20 "sync" 21 "testing" 22 "time" 23 ) 24 25 var serverAddr string 26 var once sync.Once 27 28 func echoServer(ws *Conn) { 29 defer ws.Close() 30 io.Copy(ws, ws) 31 } 32 33 type Count struct { 34 S string 35 N int 36 } 37 38 func countServer(ws *Conn) { 39 defer ws.Close() 40 for { 41 var count Count 42 err := JSON.Receive(ws, &count) 43 if err != nil { 44 return 45 } 46 count.N++ 47 count.S = strings.Repeat(count.S, count.N) 48 err = JSON.Send(ws, count) 49 if err != nil { 50 return 51 } 52 } 53 } 54 55 type testCtrlAndDataHandler struct { 56 hybiFrameHandler 57 } 58 59 func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) { 60 h.hybiFrameHandler.conn.wio.Lock() 61 defer h.hybiFrameHandler.conn.wio.Unlock() 62 w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame) 63 if err != nil { 64 return 0, err 65 } 66 n, err := w.Write(b) 67 w.Close() 68 return n, err 69 } 70 71 func ctrlAndDataServer(ws *Conn) { 72 defer ws.Close() 73 h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}} 74 ws.frameHandler = h 75 76 go func() { 77 for i := 0; ; i++ { 78 var b []byte 79 if i%2 != 0 { // with or without payload 80 b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i)) 81 } 82 if _, err := h.WritePing(b); err != nil { 83 break 84 } 85 if _, err := h.WritePong(b); err != nil { // unsolicited pong 86 break 87 } 88 time.Sleep(10 * time.Millisecond) 89 } 90 }() 91 92 b := make([]byte, 128) 93 for { 94 n, err := ws.Read(b) 95 if err != nil { 96 break 97 } 98 if _, err := ws.Write(b[:n]); err != nil { 99 break 100 } 101 } 102 } 103 104 func subProtocolHandshake(config *Config, req *http.Request) error { 105 for _, proto := range config.Protocol { 106 if proto == "chat" { 107 config.Protocol = []string{proto} 108 return nil 109 } 110 } 111 return ErrBadWebSocketProtocol 112 } 113 114 func subProtoServer(ws *Conn) { 115 for _, proto := range ws.Config().Protocol { 116 io.WriteString(ws, proto) 117 } 118 } 119 120 func startServer() { 121 http.Handle("/echo", Handler(echoServer)) 122 http.Handle("/count", Handler(countServer)) 123 http.Handle("/ctrldata", Handler(ctrlAndDataServer)) 124 subproto := Server{ 125 Handshake: subProtocolHandshake, 126 Handler: Handler(subProtoServer), 127 } 128 http.Handle("/subproto", subproto) 129 server := httptest.NewServer(nil) 130 serverAddr = server.Listener.Addr().String() 131 log.Print("Test WebSocket server listening on ", serverAddr) 132 } 133 134 func newConfig(t *testing.T, path string) *Config { 135 config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost") 136 return config 137 } 138 139 func TestEcho(t *testing.T) { 140 once.Do(startServer) 141 142 // websocket.Dial() 143 client, err := net.Dial("tcp", serverAddr) 144 if err != nil { 145 t.Fatal("dialing", err) 146 } 147 conn, err := NewClient(newConfig(t, "/echo"), client) 148 if err != nil { 149 t.Errorf("WebSocket handshake error: %v", err) 150 return 151 } 152 153 msg := []byte("hello, world\n") 154 if _, err := conn.Write(msg); err != nil { 155 t.Errorf("Write: %v", err) 156 } 157 var actual_msg = make([]byte, 512) 158 n, err := conn.Read(actual_msg) 159 if err != nil { 160 t.Errorf("Read: %v", err) 161 } 162 actual_msg = actual_msg[0:n] 163 if !bytes.Equal(msg, actual_msg) { 164 t.Errorf("Echo: expected %q got %q", msg, actual_msg) 165 } 166 conn.Close() 167 } 168 169 func TestAddr(t *testing.T) { 170 once.Do(startServer) 171 172 // websocket.Dial() 173 client, err := net.Dial("tcp", serverAddr) 174 if err != nil { 175 t.Fatal("dialing", err) 176 } 177 conn, err := NewClient(newConfig(t, "/echo"), client) 178 if err != nil { 179 t.Errorf("WebSocket handshake error: %v", err) 180 return 181 } 182 183 ra := conn.RemoteAddr().String() 184 if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") { 185 t.Errorf("Bad remote addr: %v", ra) 186 } 187 la := conn.LocalAddr().String() 188 if !strings.HasPrefix(la, "http://") { 189 t.Errorf("Bad local addr: %v", la) 190 } 191 conn.Close() 192 } 193 194 func TestCount(t *testing.T) { 195 once.Do(startServer) 196 197 // websocket.Dial() 198 client, err := net.Dial("tcp", serverAddr) 199 if err != nil { 200 t.Fatal("dialing", err) 201 } 202 conn, err := NewClient(newConfig(t, "/count"), client) 203 if err != nil { 204 t.Errorf("WebSocket handshake error: %v", err) 205 return 206 } 207 208 var count Count 209 count.S = "hello" 210 if err := JSON.Send(conn, count); err != nil { 211 t.Errorf("Write: %v", err) 212 } 213 if err := JSON.Receive(conn, &count); err != nil { 214 t.Errorf("Read: %v", err) 215 } 216 if count.N != 1 { 217 t.Errorf("count: expected %d got %d", 1, count.N) 218 } 219 if count.S != "hello" { 220 t.Errorf("count: expected %q got %q", "hello", count.S) 221 } 222 if err := JSON.Send(conn, count); err != nil { 223 t.Errorf("Write: %v", err) 224 } 225 if err := JSON.Receive(conn, &count); err != nil { 226 t.Errorf("Read: %v", err) 227 } 228 if count.N != 2 { 229 t.Errorf("count: expected %d got %d", 2, count.N) 230 } 231 if count.S != "hellohello" { 232 t.Errorf("count: expected %q got %q", "hellohello", count.S) 233 } 234 conn.Close() 235 } 236 237 func TestWithQuery(t *testing.T) { 238 once.Do(startServer) 239 240 client, err := net.Dial("tcp", serverAddr) 241 if err != nil { 242 t.Fatal("dialing", err) 243 } 244 245 config := newConfig(t, "/echo") 246 config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr)) 247 if err != nil { 248 t.Fatal("location url", err) 249 } 250 251 ws, err := NewClient(config, client) 252 if err != nil { 253 t.Errorf("WebSocket handshake: %v", err) 254 return 255 } 256 ws.Close() 257 } 258 259 func testWithProtocol(t *testing.T, subproto []string) (string, error) { 260 once.Do(startServer) 261 262 client, err := net.Dial("tcp", serverAddr) 263 if err != nil { 264 t.Fatal("dialing", err) 265 } 266 267 config := newConfig(t, "/subproto") 268 config.Protocol = subproto 269 270 ws, err := NewClient(config, client) 271 if err != nil { 272 return "", err 273 } 274 msg := make([]byte, 16) 275 n, err := ws.Read(msg) 276 if err != nil { 277 return "", err 278 } 279 ws.Close() 280 return string(msg[:n]), nil 281 } 282 283 func TestWithProtocol(t *testing.T) { 284 proto, err := testWithProtocol(t, []string{"chat"}) 285 if err != nil { 286 t.Errorf("SubProto: unexpected error: %v", err) 287 } 288 if proto != "chat" { 289 t.Errorf("SubProto: expected %q, got %q", "chat", proto) 290 } 291 } 292 293 func TestWithTwoProtocol(t *testing.T) { 294 proto, err := testWithProtocol(t, []string{"test", "chat"}) 295 if err != nil { 296 t.Errorf("SubProto: unexpected error: %v", err) 297 } 298 if proto != "chat" { 299 t.Errorf("SubProto: expected %q, got %q", "chat", proto) 300 } 301 } 302 303 func TestWithBadProtocol(t *testing.T) { 304 _, err := testWithProtocol(t, []string{"test"}) 305 if err != ErrBadStatus { 306 t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err) 307 } 308 } 309 310 func TestHTTP(t *testing.T) { 311 once.Do(startServer) 312 313 // If the client did not send a handshake that matches the protocol 314 // specification, the server MUST return an HTTP response with an 315 // appropriate error code (such as 400 Bad Request) 316 resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr)) 317 if err != nil { 318 t.Errorf("Get: error %#v", err) 319 return 320 } 321 if resp == nil { 322 t.Error("Get: resp is null") 323 return 324 } 325 if resp.StatusCode != http.StatusBadRequest { 326 t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode) 327 } 328 } 329 330 func TestTrailingSpaces(t *testing.T) { 331 // http://code.google.com/p/go/issues/detail?id=955 332 // The last runs of this create keys with trailing spaces that should not be 333 // generated by the client. 334 once.Do(startServer) 335 config := newConfig(t, "/echo") 336 for i := 0; i < 30; i++ { 337 // body 338 ws, err := DialConfig(config) 339 if err != nil { 340 t.Errorf("Dial #%d failed: %v", i, err) 341 break 342 } 343 ws.Close() 344 } 345 } 346 347 func TestDialConfigBadVersion(t *testing.T) { 348 once.Do(startServer) 349 config := newConfig(t, "/echo") 350 config.Version = 1234 351 352 _, err := DialConfig(config) 353 354 if dialerr, ok := err.(*DialError); ok { 355 if dialerr.Err != ErrBadProtocolVersion { 356 t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err) 357 } 358 } 359 } 360 361 func TestDialConfigWithDialer(t *testing.T) { 362 once.Do(startServer) 363 config := newConfig(t, "/echo") 364 config.Dialer = &net.Dialer{ 365 Deadline: time.Now().Add(-time.Minute), 366 } 367 _, err := DialConfig(config) 368 dialerr, ok := err.(*DialError) 369 if !ok { 370 t.Fatalf("DialError expected, got %#v", err) 371 } 372 neterr, ok := dialerr.Err.(*net.OpError) 373 if !ok { 374 t.Fatalf("net.OpError error expected, got %#v", dialerr.Err) 375 } 376 if !neterr.Timeout() { 377 t.Fatalf("expected timeout error, got %#v", neterr) 378 } 379 } 380 381 func TestSmallBuffer(t *testing.T) { 382 // http://code.google.com/p/go/issues/detail?id=1145 383 // Read should be able to handle reading a fragment of a frame. 384 once.Do(startServer) 385 386 // websocket.Dial() 387 client, err := net.Dial("tcp", serverAddr) 388 if err != nil { 389 t.Fatal("dialing", err) 390 } 391 conn, err := NewClient(newConfig(t, "/echo"), client) 392 if err != nil { 393 t.Errorf("WebSocket handshake error: %v", err) 394 return 395 } 396 397 msg := []byte("hello, world\n") 398 if _, err := conn.Write(msg); err != nil { 399 t.Errorf("Write: %v", err) 400 } 401 var small_msg = make([]byte, 8) 402 n, err := conn.Read(small_msg) 403 if err != nil { 404 t.Errorf("Read: %v", err) 405 } 406 if !bytes.Equal(msg[:len(small_msg)], small_msg) { 407 t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg) 408 } 409 var second_msg = make([]byte, len(msg)) 410 n, err = conn.Read(second_msg) 411 if err != nil { 412 t.Errorf("Read: %v", err) 413 } 414 second_msg = second_msg[0:n] 415 if !bytes.Equal(msg[len(small_msg):], second_msg) { 416 t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg) 417 } 418 conn.Close() 419 } 420 421 var parseAuthorityTests = []struct { 422 in *url.URL 423 out string 424 }{ 425 { 426 &url.URL{ 427 Scheme: "ws", 428 Host: "www.google.com", 429 }, 430 "www.google.com:80", 431 }, 432 { 433 &url.URL{ 434 Scheme: "wss", 435 Host: "www.google.com", 436 }, 437 "www.google.com:443", 438 }, 439 { 440 &url.URL{ 441 Scheme: "ws", 442 Host: "www.google.com:80", 443 }, 444 "www.google.com:80", 445 }, 446 { 447 &url.URL{ 448 Scheme: "wss", 449 Host: "www.google.com:443", 450 }, 451 "www.google.com:443", 452 }, 453 // some invalid ones for parseAuthority. parseAuthority doesn't 454 // concern itself with the scheme unless it actually knows about it 455 { 456 &url.URL{ 457 Scheme: "http", 458 Host: "www.google.com", 459 }, 460 "www.google.com", 461 }, 462 { 463 &url.URL{ 464 Scheme: "http", 465 Host: "www.google.com:80", 466 }, 467 "www.google.com:80", 468 }, 469 { 470 &url.URL{ 471 Scheme: "asdf", 472 Host: "127.0.0.1", 473 }, 474 "127.0.0.1", 475 }, 476 { 477 &url.URL{ 478 Scheme: "asdf", 479 Host: "www.google.com", 480 }, 481 "www.google.com", 482 }, 483 } 484 485 func TestParseAuthority(t *testing.T) { 486 for _, tt := range parseAuthorityTests { 487 out := parseAuthority(tt.in) 488 if out != tt.out { 489 t.Errorf("got %v; want %v", out, tt.out) 490 } 491 } 492 } 493 494 type closerConn struct { 495 net.Conn 496 closed int // count of the number of times Close was called 497 } 498 499 func (c *closerConn) Close() error { 500 c.closed++ 501 return c.Conn.Close() 502 } 503 504 func TestClose(t *testing.T) { 505 if runtime.GOOS == "plan9" { 506 t.Skip("see golang.org/issue/11454") 507 } 508 509 once.Do(startServer) 510 511 conn, err := net.Dial("tcp", serverAddr) 512 if err != nil { 513 t.Fatal("dialing", err) 514 } 515 516 cc := closerConn{Conn: conn} 517 518 client, err := NewClient(newConfig(t, "/echo"), &cc) 519 if err != nil { 520 t.Fatalf("WebSocket handshake: %v", err) 521 } 522 523 // set the deadline to ten minutes ago, which will have expired by the time 524 // client.Close sends the close status frame. 525 conn.SetDeadline(time.Now().Add(-10 * time.Minute)) 526 527 if err := client.Close(); err == nil { 528 t.Errorf("ws.Close(): expected error, got %v", err) 529 } 530 if cc.closed < 1 { 531 t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed) 532 } 533 } 534 535 var originTests = []struct { 536 req *http.Request 537 origin *url.URL 538 }{ 539 { 540 req: &http.Request{ 541 Header: http.Header{ 542 "Origin": []string{"http://www.example.com"}, 543 }, 544 }, 545 origin: &url.URL{ 546 Scheme: "http", 547 Host: "www.example.com", 548 }, 549 }, 550 { 551 req: &http.Request{}, 552 }, 553 } 554 555 func TestOrigin(t *testing.T) { 556 conf := newConfig(t, "/echo") 557 conf.Version = ProtocolVersionHybi13 558 for i, tt := range originTests { 559 origin, err := Origin(conf, tt.req) 560 if err != nil { 561 t.Error(err) 562 continue 563 } 564 if !reflect.DeepEqual(origin, tt.origin) { 565 t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin) 566 continue 567 } 568 } 569 } 570 571 func TestCtrlAndData(t *testing.T) { 572 once.Do(startServer) 573 574 c, err := net.Dial("tcp", serverAddr) 575 if err != nil { 576 t.Fatal(err) 577 } 578 ws, err := NewClient(newConfig(t, "/ctrldata"), c) 579 if err != nil { 580 t.Fatal(err) 581 } 582 defer ws.Close() 583 584 h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}} 585 ws.frameHandler = h 586 587 b := make([]byte, 128) 588 for i := 0; i < 2; i++ { 589 data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i)) 590 if _, err := ws.Write(data); err != nil { 591 t.Fatalf("#%d: %v", i, err) 592 } 593 var ctrl []byte 594 if i%2 != 0 { // with or without payload 595 ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i)) 596 } 597 if _, err := h.WritePing(ctrl); err != nil { 598 t.Fatalf("#%d: %v", i, err) 599 } 600 n, err := ws.Read(b) 601 if err != nil { 602 t.Fatalf("#%d: %v", i, err) 603 } 604 if !bytes.Equal(b[:n], data) { 605 t.Fatalf("#%d: got %v; want %v", i, b[:n], data) 606 } 607 } 608 } 609 610 func TestCodec_ReceiveLimited(t *testing.T) { 611 const limit = 2048 612 var payloads [][]byte 613 for _, size := range []int{ 614 1024, 615 2048, 616 4096, // receive of this message would be interrupted due to limit 617 2048, // this one is to make sure next receive recovers discarding leftovers 618 } { 619 b := make([]byte, size) 620 rand.Read(b) 621 payloads = append(payloads, b) 622 } 623 handlerDone := make(chan struct{}) 624 limitedHandler := func(ws *Conn) { 625 defer close(handlerDone) 626 ws.MaxPayloadBytes = limit 627 defer ws.Close() 628 for i, p := range payloads { 629 t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit) 630 var recv []byte 631 err := Message.Receive(ws, &recv) 632 switch err { 633 case nil: 634 case ErrFrameTooLarge: 635 if len(p) <= limit { 636 t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit) 637 } 638 continue 639 default: 640 t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err) 641 } 642 if len(recv) > limit { 643 t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit) 644 } 645 if !bytes.Equal(p, recv) { 646 t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p) 647 } 648 } 649 } 650 server := httptest.NewServer(Handler(limitedHandler)) 651 defer server.CloseClientConnections() 652 defer server.Close() 653 addr := server.Listener.Addr().String() 654 ws, err := Dial("ws://"+addr+"/", "", "http://localhost/") 655 if err != nil { 656 t.Fatal(err) 657 } 658 defer ws.Close() 659 for i, p := range payloads { 660 if err := Message.Send(ws, p); err != nil { 661 t.Fatalf("payload #%d (size %d): %v", i, len(p), err) 662 } 663 } 664 <-handlerDone 665 }