github.com/hxx258456/ccgo@v0.0.5-0.20230213014102-48b35f46f66f/net/websocket/websocket_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"net/url"
    15  	"reflect"
    16  	"runtime"
    17  	"strings"
    18  	"sync"
    19  	"testing"
    20  	"time"
    21  
    22  	"github.com/hxx258456/ccgo/gmhttp/httptest"
    23  
    24  	http "github.com/hxx258456/ccgo/gmhttp"
    25  )
    26  
    27  var serverAddr string
    28  var once sync.Once
    29  
    30  func echoServer(ws *Conn) {
    31  	defer ws.Close()
    32  	io.Copy(ws, ws)
    33  }
    34  
    35  type Count struct {
    36  	S string
    37  	N int
    38  }
    39  
    40  func countServer(ws *Conn) {
    41  	defer ws.Close()
    42  	for {
    43  		var count Count
    44  		err := JSON.Receive(ws, &count)
    45  		if err != nil {
    46  			return
    47  		}
    48  		count.N++
    49  		count.S = strings.Repeat(count.S, count.N)
    50  		err = JSON.Send(ws, count)
    51  		if err != nil {
    52  			return
    53  		}
    54  	}
    55  }
    56  
    57  type testCtrlAndDataHandler struct {
    58  	hybiFrameHandler
    59  }
    60  
    61  func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
    62  	h.hybiFrameHandler.conn.wio.Lock()
    63  	defer h.hybiFrameHandler.conn.wio.Unlock()
    64  	w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
    65  	if err != nil {
    66  		return 0, err
    67  	}
    68  	n, err := w.Write(b)
    69  	w.Close()
    70  	return n, err
    71  }
    72  
    73  func ctrlAndDataServer(ws *Conn) {
    74  	defer ws.Close()
    75  	h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
    76  	ws.frameHandler = h
    77  
    78  	go func() {
    79  		for i := 0; ; i++ {
    80  			var b []byte
    81  			if i%2 != 0 { // with or without payload
    82  				b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
    83  			}
    84  			if _, err := h.WritePing(b); err != nil {
    85  				break
    86  			}
    87  			if _, err := h.WritePong(b); err != nil { // unsolicited pong
    88  				break
    89  			}
    90  			time.Sleep(10 * time.Millisecond)
    91  		}
    92  	}()
    93  
    94  	b := make([]byte, 128)
    95  	for {
    96  		n, err := ws.Read(b)
    97  		if err != nil {
    98  			break
    99  		}
   100  		if _, err := ws.Write(b[:n]); err != nil {
   101  			break
   102  		}
   103  	}
   104  }
   105  
   106  func subProtocolHandshake(config *Config, req *http.Request) error {
   107  	for _, proto := range config.Protocol {
   108  		if proto == "chat" {
   109  			config.Protocol = []string{proto}
   110  			return nil
   111  		}
   112  	}
   113  	return ErrBadWebSocketProtocol
   114  }
   115  
   116  func subProtoServer(ws *Conn) {
   117  	for _, proto := range ws.Config().Protocol {
   118  		io.WriteString(ws, proto)
   119  	}
   120  }
   121  
   122  func startServer() {
   123  	http.Handle("/echo", Handler(echoServer))
   124  	http.Handle("/count", Handler(countServer))
   125  	http.Handle("/ctrldata", Handler(ctrlAndDataServer))
   126  	subproto := Server{
   127  		Handshake: subProtocolHandshake,
   128  		Handler:   Handler(subProtoServer),
   129  	}
   130  	http.Handle("/subproto", subproto)
   131  	server := httptest.NewServer(nil)
   132  	serverAddr = server.Listener.Addr().String()
   133  	log.Print("Test WebSocket server listening on ", serverAddr)
   134  }
   135  
   136  func newConfig(t *testing.T, path string) *Config {
   137  	config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
   138  	return config
   139  }
   140  
   141  func TestEcho(t *testing.T) {
   142  	once.Do(startServer)
   143  
   144  	// websocket.Dial()
   145  	client, err := net.Dial("tcp", serverAddr)
   146  	if err != nil {
   147  		t.Fatal("dialing", err)
   148  	}
   149  	conn, err := NewClient(newConfig(t, "/echo"), client)
   150  	if err != nil {
   151  		t.Errorf("WebSocket handshake error: %v", err)
   152  		return
   153  	}
   154  
   155  	msg := []byte("hello, world\n")
   156  	if _, err := conn.Write(msg); err != nil {
   157  		t.Errorf("Write: %v", err)
   158  	}
   159  	var actual_msg = make([]byte, 512)
   160  	n, err := conn.Read(actual_msg)
   161  	if err != nil {
   162  		t.Errorf("Read: %v", err)
   163  	}
   164  	actual_msg = actual_msg[0:n]
   165  	if !bytes.Equal(msg, actual_msg) {
   166  		t.Errorf("Echo: expected %q got %q", msg, actual_msg)
   167  	}
   168  	conn.Close()
   169  }
   170  
   171  func TestAddr(t *testing.T) {
   172  	once.Do(startServer)
   173  
   174  	// websocket.Dial()
   175  	client, err := net.Dial("tcp", serverAddr)
   176  	if err != nil {
   177  		t.Fatal("dialing", err)
   178  	}
   179  	conn, err := NewClient(newConfig(t, "/echo"), client)
   180  	if err != nil {
   181  		t.Errorf("WebSocket handshake error: %v", err)
   182  		return
   183  	}
   184  
   185  	ra := conn.RemoteAddr().String()
   186  	if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
   187  		t.Errorf("Bad remote addr: %v", ra)
   188  	}
   189  	la := conn.LocalAddr().String()
   190  	if !strings.HasPrefix(la, "http://") {
   191  		t.Errorf("Bad local addr: %v", la)
   192  	}
   193  	conn.Close()
   194  }
   195  
   196  func TestCount(t *testing.T) {
   197  	once.Do(startServer)
   198  
   199  	// websocket.Dial()
   200  	client, err := net.Dial("tcp", serverAddr)
   201  	if err != nil {
   202  		t.Fatal("dialing", err)
   203  	}
   204  	conn, err := NewClient(newConfig(t, "/count"), client)
   205  	if err != nil {
   206  		t.Errorf("WebSocket handshake error: %v", err)
   207  		return
   208  	}
   209  
   210  	var count Count
   211  	count.S = "hello"
   212  	if err := JSON.Send(conn, count); err != nil {
   213  		t.Errorf("Write: %v", err)
   214  	}
   215  	if err := JSON.Receive(conn, &count); err != nil {
   216  		t.Errorf("Read: %v", err)
   217  	}
   218  	if count.N != 1 {
   219  		t.Errorf("count: expected %d got %d", 1, count.N)
   220  	}
   221  	if count.S != "hello" {
   222  		t.Errorf("count: expected %q got %q", "hello", count.S)
   223  	}
   224  	if err := JSON.Send(conn, count); err != nil {
   225  		t.Errorf("Write: %v", err)
   226  	}
   227  	if err := JSON.Receive(conn, &count); err != nil {
   228  		t.Errorf("Read: %v", err)
   229  	}
   230  	if count.N != 2 {
   231  		t.Errorf("count: expected %d got %d", 2, count.N)
   232  	}
   233  	if count.S != "hellohello" {
   234  		t.Errorf("count: expected %q got %q", "hellohello", count.S)
   235  	}
   236  	conn.Close()
   237  }
   238  
   239  func TestWithQuery(t *testing.T) {
   240  	once.Do(startServer)
   241  
   242  	client, err := net.Dial("tcp", serverAddr)
   243  	if err != nil {
   244  		t.Fatal("dialing", err)
   245  	}
   246  
   247  	config := newConfig(t, "/echo")
   248  	config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
   249  	if err != nil {
   250  		t.Fatal("location url", err)
   251  	}
   252  
   253  	ws, err := NewClient(config, client)
   254  	if err != nil {
   255  		t.Errorf("WebSocket handshake: %v", err)
   256  		return
   257  	}
   258  	ws.Close()
   259  }
   260  
   261  func testWithProtocol(t *testing.T, subproto []string) (string, error) {
   262  	once.Do(startServer)
   263  
   264  	client, err := net.Dial("tcp", serverAddr)
   265  	if err != nil {
   266  		t.Fatal("dialing", err)
   267  	}
   268  
   269  	config := newConfig(t, "/subproto")
   270  	config.Protocol = subproto
   271  
   272  	ws, err := NewClient(config, client)
   273  	if err != nil {
   274  		return "", err
   275  	}
   276  	msg := make([]byte, 16)
   277  	n, err := ws.Read(msg)
   278  	if err != nil {
   279  		return "", err
   280  	}
   281  	ws.Close()
   282  	return string(msg[:n]), nil
   283  }
   284  
   285  func TestWithProtocol(t *testing.T) {
   286  	proto, err := testWithProtocol(t, []string{"chat"})
   287  	if err != nil {
   288  		t.Errorf("SubProto: unexpected error: %v", err)
   289  	}
   290  	if proto != "chat" {
   291  		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
   292  	}
   293  }
   294  
   295  func TestWithTwoProtocol(t *testing.T) {
   296  	proto, err := testWithProtocol(t, []string{"test", "chat"})
   297  	if err != nil {
   298  		t.Errorf("SubProto: unexpected error: %v", err)
   299  	}
   300  	if proto != "chat" {
   301  		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
   302  	}
   303  }
   304  
   305  func TestWithBadProtocol(t *testing.T) {
   306  	_, err := testWithProtocol(t, []string{"test"})
   307  	if err != ErrBadStatus {
   308  		t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
   309  	}
   310  }
   311  
   312  func TestHTTP(t *testing.T) {
   313  	once.Do(startServer)
   314  
   315  	// If the client did not send a handshake that matches the protocol
   316  	// specification, the server MUST return an HTTP response with an
   317  	// appropriate error code (such as 400 Bad Request)
   318  	resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
   319  	if err != nil {
   320  		t.Errorf("Get: error %#v", err)
   321  		return
   322  	}
   323  	if resp == nil {
   324  		t.Error("Get: resp is null")
   325  		return
   326  	}
   327  	if resp.StatusCode != http.StatusBadRequest {
   328  		t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
   329  	}
   330  }
   331  
   332  func TestTrailingSpaces(t *testing.T) {
   333  	// http://code.google.com/p/go/issues/detail?id=955
   334  	// The last runs of this create keys with trailing spaces that should not be
   335  	// generated by the client.
   336  	once.Do(startServer)
   337  	config := newConfig(t, "/echo")
   338  	for i := 0; i < 30; i++ {
   339  		// body
   340  		ws, err := DialConfig(config)
   341  		if err != nil {
   342  			t.Errorf("Dial #%d failed: %v", i, err)
   343  			break
   344  		}
   345  		ws.Close()
   346  	}
   347  }
   348  
   349  func TestDialConfigBadVersion(t *testing.T) {
   350  	once.Do(startServer)
   351  	config := newConfig(t, "/echo")
   352  	config.Version = 1234
   353  
   354  	_, err := DialConfig(config)
   355  
   356  	if dialerr, ok := err.(*DialError); ok {
   357  		if dialerr.Err != ErrBadProtocolVersion {
   358  			t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
   359  		}
   360  	}
   361  }
   362  
   363  func TestDialConfigWithDialer(t *testing.T) {
   364  	once.Do(startServer)
   365  	config := newConfig(t, "/echo")
   366  	config.Dialer = &net.Dialer{
   367  		Deadline: time.Now().Add(-time.Minute),
   368  	}
   369  	_, err := DialConfig(config)
   370  	dialerr, ok := err.(*DialError)
   371  	if !ok {
   372  		t.Fatalf("DialError expected, got %#v", err)
   373  	}
   374  	neterr, ok := dialerr.Err.(*net.OpError)
   375  	if !ok {
   376  		t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
   377  	}
   378  	if !neterr.Timeout() {
   379  		t.Fatalf("expected timeout error, got %#v", neterr)
   380  	}
   381  }
   382  
   383  func TestSmallBuffer(t *testing.T) {
   384  	// http://code.google.com/p/go/issues/detail?id=1145
   385  	// Read should be able to handle reading a fragment of a frame.
   386  	once.Do(startServer)
   387  
   388  	// websocket.Dial()
   389  	client, err := net.Dial("tcp", serverAddr)
   390  	if err != nil {
   391  		t.Fatal("dialing", err)
   392  	}
   393  	conn, err := NewClient(newConfig(t, "/echo"), client)
   394  	if err != nil {
   395  		t.Errorf("WebSocket handshake error: %v", err)
   396  		return
   397  	}
   398  
   399  	msg := []byte("hello, world\n")
   400  	if _, err := conn.Write(msg); err != nil {
   401  		t.Errorf("Write: %v", err)
   402  	}
   403  	var small_msg = make([]byte, 8)
   404  	n, err := conn.Read(small_msg)
   405  	if err != nil {
   406  		t.Errorf("Read: %v", err)
   407  	}
   408  	if !bytes.Equal(msg[:len(small_msg)], small_msg) {
   409  		t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
   410  	}
   411  	var second_msg = make([]byte, len(msg))
   412  	n, err = conn.Read(second_msg)
   413  	if err != nil {
   414  		t.Errorf("Read: %v", err)
   415  	}
   416  	second_msg = second_msg[0:n]
   417  	if !bytes.Equal(msg[len(small_msg):], second_msg) {
   418  		t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
   419  	}
   420  	conn.Close()
   421  }
   422  
   423  var parseAuthorityTests = []struct {
   424  	in  *url.URL
   425  	out string
   426  }{
   427  	{
   428  		&url.URL{
   429  			Scheme: "ws",
   430  			Host:   "www.google.com",
   431  		},
   432  		"www.google.com:80",
   433  	},
   434  	{
   435  		&url.URL{
   436  			Scheme: "wss",
   437  			Host:   "www.google.com",
   438  		},
   439  		"www.google.com:443",
   440  	},
   441  	{
   442  		&url.URL{
   443  			Scheme: "ws",
   444  			Host:   "www.google.com:80",
   445  		},
   446  		"www.google.com:80",
   447  	},
   448  	{
   449  		&url.URL{
   450  			Scheme: "wss",
   451  			Host:   "www.google.com:443",
   452  		},
   453  		"www.google.com:443",
   454  	},
   455  	// some invalid ones for parseAuthority. parseAuthority doesn't
   456  	// concern itself with the scheme unless it actually knows about it
   457  	{
   458  		&url.URL{
   459  			Scheme: "http",
   460  			Host:   "www.google.com",
   461  		},
   462  		"www.google.com",
   463  	},
   464  	{
   465  		&url.URL{
   466  			Scheme: "http",
   467  			Host:   "www.google.com:80",
   468  		},
   469  		"www.google.com:80",
   470  	},
   471  	{
   472  		&url.URL{
   473  			Scheme: "asdf",
   474  			Host:   "127.0.0.1",
   475  		},
   476  		"127.0.0.1",
   477  	},
   478  	{
   479  		&url.URL{
   480  			Scheme: "asdf",
   481  			Host:   "www.google.com",
   482  		},
   483  		"www.google.com",
   484  	},
   485  }
   486  
   487  func TestParseAuthority(t *testing.T) {
   488  	for _, tt := range parseAuthorityTests {
   489  		out := parseAuthority(tt.in)
   490  		if out != tt.out {
   491  			t.Errorf("got %v; want %v", out, tt.out)
   492  		}
   493  	}
   494  }
   495  
   496  type closerConn struct {
   497  	net.Conn
   498  	closed int // count of the number of times Close was called
   499  }
   500  
   501  func (c *closerConn) Close() error {
   502  	c.closed++
   503  	return c.Conn.Close()
   504  }
   505  
   506  func TestClose(t *testing.T) {
   507  	if runtime.GOOS == "plan9" {
   508  		t.Skip("see golang.org/issue/11454")
   509  	}
   510  
   511  	once.Do(startServer)
   512  
   513  	conn, err := net.Dial("tcp", serverAddr)
   514  	if err != nil {
   515  		t.Fatal("dialing", err)
   516  	}
   517  
   518  	cc := closerConn{Conn: conn}
   519  
   520  	client, err := NewClient(newConfig(t, "/echo"), &cc)
   521  	if err != nil {
   522  		t.Fatalf("WebSocket handshake: %v", err)
   523  	}
   524  
   525  	// set the deadline to ten minutes ago, which will have expired by the time
   526  	// client.Close sends the close status frame.
   527  	conn.SetDeadline(time.Now().Add(-10 * time.Minute))
   528  
   529  	if err := client.Close(); err == nil {
   530  		t.Errorf("ws.Close(): expected error, got %v", err)
   531  	}
   532  	if cc.closed < 1 {
   533  		t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
   534  	}
   535  }
   536  
   537  var originTests = []struct {
   538  	req    *http.Request
   539  	origin *url.URL
   540  }{
   541  	{
   542  		req: &http.Request{
   543  			Header: http.Header{
   544  				"Origin": []string{"http://www.example.com"},
   545  			},
   546  		},
   547  		origin: &url.URL{
   548  			Scheme: "http",
   549  			Host:   "www.example.com",
   550  		},
   551  	},
   552  	{
   553  		req: &http.Request{},
   554  	},
   555  }
   556  
   557  func TestOrigin(t *testing.T) {
   558  	conf := newConfig(t, "/echo")
   559  	conf.Version = ProtocolVersionHybi13
   560  	for i, tt := range originTests {
   561  		origin, err := Origin(conf, tt.req)
   562  		if err != nil {
   563  			t.Error(err)
   564  			continue
   565  		}
   566  		if !reflect.DeepEqual(origin, tt.origin) {
   567  			t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
   568  			continue
   569  		}
   570  	}
   571  }
   572  
   573  func TestCtrlAndData(t *testing.T) {
   574  	once.Do(startServer)
   575  
   576  	c, err := net.Dial("tcp", serverAddr)
   577  	if err != nil {
   578  		t.Fatal(err)
   579  	}
   580  	ws, err := NewClient(newConfig(t, "/ctrldata"), c)
   581  	if err != nil {
   582  		t.Fatal(err)
   583  	}
   584  	defer ws.Close()
   585  
   586  	h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
   587  	ws.frameHandler = h
   588  
   589  	b := make([]byte, 128)
   590  	for i := 0; i < 2; i++ {
   591  		data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
   592  		if _, err := ws.Write(data); err != nil {
   593  			t.Fatalf("#%d: %v", i, err)
   594  		}
   595  		var ctrl []byte
   596  		if i%2 != 0 { // with or without payload
   597  			ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
   598  		}
   599  		if _, err := h.WritePing(ctrl); err != nil {
   600  			t.Fatalf("#%d: %v", i, err)
   601  		}
   602  		n, err := ws.Read(b)
   603  		if err != nil {
   604  			t.Fatalf("#%d: %v", i, err)
   605  		}
   606  		if !bytes.Equal(b[:n], data) {
   607  			t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
   608  		}
   609  	}
   610  }
   611  
   612  func TestCodec_ReceiveLimited(t *testing.T) {
   613  	const limit = 2048
   614  	var payloads [][]byte
   615  	for _, size := range []int{
   616  		1024,
   617  		2048,
   618  		4096, // receive of this message would be interrupted due to limit
   619  		2048, // this one is to make sure next receive recovers discarding leftovers
   620  	} {
   621  		b := make([]byte, size)
   622  		rand.Read(b)
   623  		payloads = append(payloads, b)
   624  	}
   625  	handlerDone := make(chan struct{})
   626  	limitedHandler := func(ws *Conn) {
   627  		defer close(handlerDone)
   628  		ws.MaxPayloadBytes = limit
   629  		defer ws.Close()
   630  		for i, p := range payloads {
   631  			t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
   632  			var recv []byte
   633  			err := Message.Receive(ws, &recv)
   634  			switch err {
   635  			case nil:
   636  			case ErrFrameTooLarge:
   637  				if len(p) <= limit {
   638  					t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
   639  				}
   640  				continue
   641  			default:
   642  				t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
   643  			}
   644  			if len(recv) > limit {
   645  				t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
   646  			}
   647  			if !bytes.Equal(p, recv) {
   648  				t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
   649  			}
   650  		}
   651  	}
   652  	server := httptest.NewServer(Handler(limitedHandler))
   653  	defer server.CloseClientConnections()
   654  	defer server.Close()
   655  	addr := server.Listener.Addr().String()
   656  	ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
   657  	if err != nil {
   658  		t.Fatal(err)
   659  	}
   660  	defer ws.Close()
   661  	for i, p := range payloads {
   662  		if err := Message.Send(ws, p); err != nil {
   663  			t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
   664  		}
   665  	}
   666  	<-handlerDone
   667  }