github.com/graybobo/golang.org-package-offline-cache@v0.0.0-20200626051047-6608995c132f/x/net/websocket/websocket_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"net/url"
    16  	"reflect"
    17  	"runtime"
    18  	"strings"
    19  	"sync"
    20  	"testing"
    21  	"time"
    22  )
    23  
    24  var serverAddr string
    25  var once sync.Once
    26  
    27  func echoServer(ws *Conn) {
    28  	defer ws.Close()
    29  	io.Copy(ws, ws)
    30  }
    31  
    32  type Count struct {
    33  	S string
    34  	N int
    35  }
    36  
    37  func countServer(ws *Conn) {
    38  	defer ws.Close()
    39  	for {
    40  		var count Count
    41  		err := JSON.Receive(ws, &count)
    42  		if err != nil {
    43  			return
    44  		}
    45  		count.N++
    46  		count.S = strings.Repeat(count.S, count.N)
    47  		err = JSON.Send(ws, count)
    48  		if err != nil {
    49  			return
    50  		}
    51  	}
    52  }
    53  
    54  type testCtrlAndDataHandler struct {
    55  	hybiFrameHandler
    56  }
    57  
    58  func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
    59  	h.hybiFrameHandler.conn.wio.Lock()
    60  	defer h.hybiFrameHandler.conn.wio.Unlock()
    61  	w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
    62  	if err != nil {
    63  		return 0, err
    64  	}
    65  	n, err := w.Write(b)
    66  	w.Close()
    67  	return n, err
    68  }
    69  
    70  func ctrlAndDataServer(ws *Conn) {
    71  	defer ws.Close()
    72  	h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
    73  	ws.frameHandler = h
    74  
    75  	go func() {
    76  		for i := 0; ; i++ {
    77  			var b []byte
    78  			if i%2 != 0 { // with or without payload
    79  				b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
    80  			}
    81  			if _, err := h.WritePing(b); err != nil {
    82  				break
    83  			}
    84  			if _, err := h.WritePong(b); err != nil { // unsolicited pong
    85  				break
    86  			}
    87  			time.Sleep(10 * time.Millisecond)
    88  		}
    89  	}()
    90  
    91  	b := make([]byte, 128)
    92  	for {
    93  		n, err := ws.Read(b)
    94  		if err != nil {
    95  			break
    96  		}
    97  		if _, err := ws.Write(b[:n]); err != nil {
    98  			break
    99  		}
   100  	}
   101  }
   102  
   103  func subProtocolHandshake(config *Config, req *http.Request) error {
   104  	for _, proto := range config.Protocol {
   105  		if proto == "chat" {
   106  			config.Protocol = []string{proto}
   107  			return nil
   108  		}
   109  	}
   110  	return ErrBadWebSocketProtocol
   111  }
   112  
   113  func subProtoServer(ws *Conn) {
   114  	for _, proto := range ws.Config().Protocol {
   115  		io.WriteString(ws, proto)
   116  	}
   117  }
   118  
   119  func startServer() {
   120  	http.Handle("/echo", Handler(echoServer))
   121  	http.Handle("/count", Handler(countServer))
   122  	http.Handle("/ctrldata", Handler(ctrlAndDataServer))
   123  	subproto := Server{
   124  		Handshake: subProtocolHandshake,
   125  		Handler:   Handler(subProtoServer),
   126  	}
   127  	http.Handle("/subproto", subproto)
   128  	server := httptest.NewServer(nil)
   129  	serverAddr = server.Listener.Addr().String()
   130  	log.Print("Test WebSocket server listening on ", serverAddr)
   131  }
   132  
   133  func newConfig(t *testing.T, path string) *Config {
   134  	config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
   135  	return config
   136  }
   137  
   138  func TestEcho(t *testing.T) {
   139  	once.Do(startServer)
   140  
   141  	// websocket.Dial()
   142  	client, err := net.Dial("tcp", serverAddr)
   143  	if err != nil {
   144  		t.Fatal("dialing", err)
   145  	}
   146  	conn, err := NewClient(newConfig(t, "/echo"), client)
   147  	if err != nil {
   148  		t.Errorf("WebSocket handshake error: %v", err)
   149  		return
   150  	}
   151  
   152  	msg := []byte("hello, world\n")
   153  	if _, err := conn.Write(msg); err != nil {
   154  		t.Errorf("Write: %v", err)
   155  	}
   156  	var actual_msg = make([]byte, 512)
   157  	n, err := conn.Read(actual_msg)
   158  	if err != nil {
   159  		t.Errorf("Read: %v", err)
   160  	}
   161  	actual_msg = actual_msg[0:n]
   162  	if !bytes.Equal(msg, actual_msg) {
   163  		t.Errorf("Echo: expected %q got %q", msg, actual_msg)
   164  	}
   165  	conn.Close()
   166  }
   167  
   168  func TestAddr(t *testing.T) {
   169  	once.Do(startServer)
   170  
   171  	// websocket.Dial()
   172  	client, err := net.Dial("tcp", serverAddr)
   173  	if err != nil {
   174  		t.Fatal("dialing", err)
   175  	}
   176  	conn, err := NewClient(newConfig(t, "/echo"), client)
   177  	if err != nil {
   178  		t.Errorf("WebSocket handshake error: %v", err)
   179  		return
   180  	}
   181  
   182  	ra := conn.RemoteAddr().String()
   183  	if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
   184  		t.Errorf("Bad remote addr: %v", ra)
   185  	}
   186  	la := conn.LocalAddr().String()
   187  	if !strings.HasPrefix(la, "http://") {
   188  		t.Errorf("Bad local addr: %v", la)
   189  	}
   190  	conn.Close()
   191  }
   192  
   193  func TestCount(t *testing.T) {
   194  	once.Do(startServer)
   195  
   196  	// websocket.Dial()
   197  	client, err := net.Dial("tcp", serverAddr)
   198  	if err != nil {
   199  		t.Fatal("dialing", err)
   200  	}
   201  	conn, err := NewClient(newConfig(t, "/count"), client)
   202  	if err != nil {
   203  		t.Errorf("WebSocket handshake error: %v", err)
   204  		return
   205  	}
   206  
   207  	var count Count
   208  	count.S = "hello"
   209  	if err := JSON.Send(conn, count); err != nil {
   210  		t.Errorf("Write: %v", err)
   211  	}
   212  	if err := JSON.Receive(conn, &count); err != nil {
   213  		t.Errorf("Read: %v", err)
   214  	}
   215  	if count.N != 1 {
   216  		t.Errorf("count: expected %d got %d", 1, count.N)
   217  	}
   218  	if count.S != "hello" {
   219  		t.Errorf("count: expected %q got %q", "hello", count.S)
   220  	}
   221  	if err := JSON.Send(conn, count); err != nil {
   222  		t.Errorf("Write: %v", err)
   223  	}
   224  	if err := JSON.Receive(conn, &count); err != nil {
   225  		t.Errorf("Read: %v", err)
   226  	}
   227  	if count.N != 2 {
   228  		t.Errorf("count: expected %d got %d", 2, count.N)
   229  	}
   230  	if count.S != "hellohello" {
   231  		t.Errorf("count: expected %q got %q", "hellohello", count.S)
   232  	}
   233  	conn.Close()
   234  }
   235  
   236  func TestWithQuery(t *testing.T) {
   237  	once.Do(startServer)
   238  
   239  	client, err := net.Dial("tcp", serverAddr)
   240  	if err != nil {
   241  		t.Fatal("dialing", err)
   242  	}
   243  
   244  	config := newConfig(t, "/echo")
   245  	config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
   246  	if err != nil {
   247  		t.Fatal("location url", err)
   248  	}
   249  
   250  	ws, err := NewClient(config, client)
   251  	if err != nil {
   252  		t.Errorf("WebSocket handshake: %v", err)
   253  		return
   254  	}
   255  	ws.Close()
   256  }
   257  
   258  func testWithProtocol(t *testing.T, subproto []string) (string, error) {
   259  	once.Do(startServer)
   260  
   261  	client, err := net.Dial("tcp", serverAddr)
   262  	if err != nil {
   263  		t.Fatal("dialing", err)
   264  	}
   265  
   266  	config := newConfig(t, "/subproto")
   267  	config.Protocol = subproto
   268  
   269  	ws, err := NewClient(config, client)
   270  	if err != nil {
   271  		return "", err
   272  	}
   273  	msg := make([]byte, 16)
   274  	n, err := ws.Read(msg)
   275  	if err != nil {
   276  		return "", err
   277  	}
   278  	ws.Close()
   279  	return string(msg[:n]), nil
   280  }
   281  
   282  func TestWithProtocol(t *testing.T) {
   283  	proto, err := testWithProtocol(t, []string{"chat"})
   284  	if err != nil {
   285  		t.Errorf("SubProto: unexpected error: %v", err)
   286  	}
   287  	if proto != "chat" {
   288  		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
   289  	}
   290  }
   291  
   292  func TestWithTwoProtocol(t *testing.T) {
   293  	proto, err := testWithProtocol(t, []string{"test", "chat"})
   294  	if err != nil {
   295  		t.Errorf("SubProto: unexpected error: %v", err)
   296  	}
   297  	if proto != "chat" {
   298  		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
   299  	}
   300  }
   301  
   302  func TestWithBadProtocol(t *testing.T) {
   303  	_, err := testWithProtocol(t, []string{"test"})
   304  	if err != ErrBadStatus {
   305  		t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
   306  	}
   307  }
   308  
   309  func TestHTTP(t *testing.T) {
   310  	once.Do(startServer)
   311  
   312  	// If the client did not send a handshake that matches the protocol
   313  	// specification, the server MUST return an HTTP response with an
   314  	// appropriate error code (such as 400 Bad Request)
   315  	resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
   316  	if err != nil {
   317  		t.Errorf("Get: error %#v", err)
   318  		return
   319  	}
   320  	if resp == nil {
   321  		t.Error("Get: resp is null")
   322  		return
   323  	}
   324  	if resp.StatusCode != http.StatusBadRequest {
   325  		t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
   326  	}
   327  }
   328  
   329  func TestTrailingSpaces(t *testing.T) {
   330  	// http://code.google.com/p/go/issues/detail?id=955
   331  	// The last runs of this create keys with trailing spaces that should not be
   332  	// generated by the client.
   333  	once.Do(startServer)
   334  	config := newConfig(t, "/echo")
   335  	for i := 0; i < 30; i++ {
   336  		// body
   337  		ws, err := DialConfig(config)
   338  		if err != nil {
   339  			t.Errorf("Dial #%d failed: %v", i, err)
   340  			break
   341  		}
   342  		ws.Close()
   343  	}
   344  }
   345  
   346  func TestDialConfigBadVersion(t *testing.T) {
   347  	once.Do(startServer)
   348  	config := newConfig(t, "/echo")
   349  	config.Version = 1234
   350  
   351  	_, err := DialConfig(config)
   352  
   353  	if dialerr, ok := err.(*DialError); ok {
   354  		if dialerr.Err != ErrBadProtocolVersion {
   355  			t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
   356  		}
   357  	}
   358  }
   359  
   360  func TestSmallBuffer(t *testing.T) {
   361  	// http://code.google.com/p/go/issues/detail?id=1145
   362  	// Read should be able to handle reading a fragment of a frame.
   363  	once.Do(startServer)
   364  
   365  	// websocket.Dial()
   366  	client, err := net.Dial("tcp", serverAddr)
   367  	if err != nil {
   368  		t.Fatal("dialing", err)
   369  	}
   370  	conn, err := NewClient(newConfig(t, "/echo"), client)
   371  	if err != nil {
   372  		t.Errorf("WebSocket handshake error: %v", err)
   373  		return
   374  	}
   375  
   376  	msg := []byte("hello, world\n")
   377  	if _, err := conn.Write(msg); err != nil {
   378  		t.Errorf("Write: %v", err)
   379  	}
   380  	var small_msg = make([]byte, 8)
   381  	n, err := conn.Read(small_msg)
   382  	if err != nil {
   383  		t.Errorf("Read: %v", err)
   384  	}
   385  	if !bytes.Equal(msg[:len(small_msg)], small_msg) {
   386  		t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
   387  	}
   388  	var second_msg = make([]byte, len(msg))
   389  	n, err = conn.Read(second_msg)
   390  	if err != nil {
   391  		t.Errorf("Read: %v", err)
   392  	}
   393  	second_msg = second_msg[0:n]
   394  	if !bytes.Equal(msg[len(small_msg):], second_msg) {
   395  		t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
   396  	}
   397  	conn.Close()
   398  }
   399  
   400  var parseAuthorityTests = []struct {
   401  	in  *url.URL
   402  	out string
   403  }{
   404  	{
   405  		&url.URL{
   406  			Scheme: "ws",
   407  			Host:   "www.google.com",
   408  		},
   409  		"www.google.com:80",
   410  	},
   411  	{
   412  		&url.URL{
   413  			Scheme: "wss",
   414  			Host:   "www.google.com",
   415  		},
   416  		"www.google.com:443",
   417  	},
   418  	{
   419  		&url.URL{
   420  			Scheme: "ws",
   421  			Host:   "www.google.com:80",
   422  		},
   423  		"www.google.com:80",
   424  	},
   425  	{
   426  		&url.URL{
   427  			Scheme: "wss",
   428  			Host:   "www.google.com:443",
   429  		},
   430  		"www.google.com:443",
   431  	},
   432  	// some invalid ones for parseAuthority. parseAuthority doesn't
   433  	// concern itself with the scheme unless it actually knows about it
   434  	{
   435  		&url.URL{
   436  			Scheme: "http",
   437  			Host:   "www.google.com",
   438  		},
   439  		"www.google.com",
   440  	},
   441  	{
   442  		&url.URL{
   443  			Scheme: "http",
   444  			Host:   "www.google.com:80",
   445  		},
   446  		"www.google.com:80",
   447  	},
   448  	{
   449  		&url.URL{
   450  			Scheme: "asdf",
   451  			Host:   "127.0.0.1",
   452  		},
   453  		"127.0.0.1",
   454  	},
   455  	{
   456  		&url.URL{
   457  			Scheme: "asdf",
   458  			Host:   "www.google.com",
   459  		},
   460  		"www.google.com",
   461  	},
   462  }
   463  
   464  func TestParseAuthority(t *testing.T) {
   465  	for _, tt := range parseAuthorityTests {
   466  		out := parseAuthority(tt.in)
   467  		if out != tt.out {
   468  			t.Errorf("got %v; want %v", out, tt.out)
   469  		}
   470  	}
   471  }
   472  
   473  type closerConn struct {
   474  	net.Conn
   475  	closed int // count of the number of times Close was called
   476  }
   477  
   478  func (c *closerConn) Close() error {
   479  	c.closed++
   480  	return c.Conn.Close()
   481  }
   482  
   483  func TestClose(t *testing.T) {
   484  	if runtime.GOOS == "plan9" {
   485  		t.Skip("see golang.org/issue/11454")
   486  	}
   487  
   488  	once.Do(startServer)
   489  
   490  	conn, err := net.Dial("tcp", serverAddr)
   491  	if err != nil {
   492  		t.Fatal("dialing", err)
   493  	}
   494  
   495  	cc := closerConn{Conn: conn}
   496  
   497  	client, err := NewClient(newConfig(t, "/echo"), &cc)
   498  	if err != nil {
   499  		t.Fatalf("WebSocket handshake: %v", err)
   500  	}
   501  
   502  	// set the deadline to ten minutes ago, which will have expired by the time
   503  	// client.Close sends the close status frame.
   504  	conn.SetDeadline(time.Now().Add(-10 * time.Minute))
   505  
   506  	if err := client.Close(); err == nil {
   507  		t.Errorf("ws.Close(): expected error, got %v", err)
   508  	}
   509  	if cc.closed < 1 {
   510  		t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
   511  	}
   512  }
   513  
   514  var originTests = []struct {
   515  	req    *http.Request
   516  	origin *url.URL
   517  }{
   518  	{
   519  		req: &http.Request{
   520  			Header: http.Header{
   521  				"Origin": []string{"http://www.example.com"},
   522  			},
   523  		},
   524  		origin: &url.URL{
   525  			Scheme: "http",
   526  			Host:   "www.example.com",
   527  		},
   528  	},
   529  	{
   530  		req: &http.Request{},
   531  	},
   532  }
   533  
   534  func TestOrigin(t *testing.T) {
   535  	conf := newConfig(t, "/echo")
   536  	conf.Version = ProtocolVersionHybi13
   537  	for i, tt := range originTests {
   538  		origin, err := Origin(conf, tt.req)
   539  		if err != nil {
   540  			t.Error(err)
   541  			continue
   542  		}
   543  		if !reflect.DeepEqual(origin, tt.origin) {
   544  			t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
   545  			continue
   546  		}
   547  	}
   548  }
   549  
   550  func TestCtrlAndData(t *testing.T) {
   551  	once.Do(startServer)
   552  
   553  	c, err := net.Dial("tcp", serverAddr)
   554  	if err != nil {
   555  		t.Fatal(err)
   556  	}
   557  	ws, err := NewClient(newConfig(t, "/ctrldata"), c)
   558  	if err != nil {
   559  		t.Fatal(err)
   560  	}
   561  	defer ws.Close()
   562  
   563  	h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
   564  	ws.frameHandler = h
   565  
   566  	b := make([]byte, 128)
   567  	for i := 0; i < 2; i++ {
   568  		data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
   569  		if _, err := ws.Write(data); err != nil {
   570  			t.Fatalf("#%d: %v", i, err)
   571  		}
   572  		var ctrl []byte
   573  		if i%2 != 0 { // with or without payload
   574  			ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
   575  		}
   576  		if _, err := h.WritePing(ctrl); err != nil {
   577  			t.Fatalf("#%d: %v", i, err)
   578  		}
   579  		n, err := ws.Read(b)
   580  		if err != nil {
   581  			t.Fatalf("#%d: %v", i, err)
   582  		}
   583  		if !bytes.Equal(b[:n], data) {
   584  			t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
   585  		}
   586  	}
   587  }