golang.org/x/net@v0.25.1-0.20240516223405-c87a5b62e243/websocket/websocket_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bytes"
     9  	"crypto/rand"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net"
    14  	"net/http"
    15  	"net/http/httptest"
    16  	"net/url"
    17  	"reflect"
    18  	"runtime"
    19  	"strings"
    20  	"sync"
    21  	"testing"
    22  	"time"
    23  )
    24  
    25  var serverAddr string
    26  var once sync.Once
    27  
    28  func echoServer(ws *Conn) {
    29  	defer ws.Close()
    30  	io.Copy(ws, ws)
    31  }
    32  
// Count is the JSON message exchanged with countServer. The server
// increments N and replaces S with S repeated N times before replying.
type Count struct {
	S string // string payload
	N int    // counter value
}
    37  
    38  func countServer(ws *Conn) {
    39  	defer ws.Close()
    40  	for {
    41  		var count Count
    42  		err := JSON.Receive(ws, &count)
    43  		if err != nil {
    44  			return
    45  		}
    46  		count.N++
    47  		count.S = strings.Repeat(count.S, count.N)
    48  		err = JSON.Send(ws, count)
    49  		if err != nil {
    50  			return
    51  		}
    52  	}
    53  }
    54  
// testCtrlAndDataHandler embeds hybiFrameHandler and adds a WritePing
// method, so tests can emit ping control frames in addition to what the
// embedded handler already provides.
type testCtrlAndDataHandler struct {
	hybiFrameHandler
}
    58  
    59  func (h *testCtrlAndDataHandler) WritePing(b []byte) (int, error) {
    60  	h.hybiFrameHandler.conn.wio.Lock()
    61  	defer h.hybiFrameHandler.conn.wio.Unlock()
    62  	w, err := h.hybiFrameHandler.conn.frameWriterFactory.NewFrameWriter(PingFrame)
    63  	if err != nil {
    64  		return 0, err
    65  	}
    66  	n, err := w.Write(b)
    67  	w.Close()
    68  	return n, err
    69  }
    70  
    71  func ctrlAndDataServer(ws *Conn) {
    72  	defer ws.Close()
    73  	h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
    74  	ws.frameHandler = h
    75  
    76  	go func() {
    77  		for i := 0; ; i++ {
    78  			var b []byte
    79  			if i%2 != 0 { // with or without payload
    80  				b = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-SERVER", i))
    81  			}
    82  			if _, err := h.WritePing(b); err != nil {
    83  				break
    84  			}
    85  			if _, err := h.WritePong(b); err != nil { // unsolicited pong
    86  				break
    87  			}
    88  			time.Sleep(10 * time.Millisecond)
    89  		}
    90  	}()
    91  
    92  	b := make([]byte, 128)
    93  	for {
    94  		n, err := ws.Read(b)
    95  		if err != nil {
    96  			break
    97  		}
    98  		if _, err := ws.Write(b[:n]); err != nil {
    99  			break
   100  		}
   101  	}
   102  }
   103  
   104  func subProtocolHandshake(config *Config, req *http.Request) error {
   105  	for _, proto := range config.Protocol {
   106  		if proto == "chat" {
   107  			config.Protocol = []string{proto}
   108  			return nil
   109  		}
   110  	}
   111  	return ErrBadWebSocketProtocol
   112  }
   113  
   114  func subProtoServer(ws *Conn) {
   115  	for _, proto := range ws.Config().Protocol {
   116  		io.WriteString(ws, proto)
   117  	}
   118  }
   119  
   120  func startServer() {
   121  	http.Handle("/echo", Handler(echoServer))
   122  	http.Handle("/count", Handler(countServer))
   123  	http.Handle("/ctrldata", Handler(ctrlAndDataServer))
   124  	subproto := Server{
   125  		Handshake: subProtocolHandshake,
   126  		Handler:   Handler(subProtoServer),
   127  	}
   128  	http.Handle("/subproto", subproto)
   129  	server := httptest.NewServer(nil)
   130  	serverAddr = server.Listener.Addr().String()
   131  	log.Print("Test WebSocket server listening on ", serverAddr)
   132  }
   133  
   134  func newConfig(t *testing.T, path string) *Config {
   135  	config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
   136  	return config
   137  }
   138  
   139  func TestEcho(t *testing.T) {
   140  	once.Do(startServer)
   141  
   142  	// websocket.Dial()
   143  	client, err := net.Dial("tcp", serverAddr)
   144  	if err != nil {
   145  		t.Fatal("dialing", err)
   146  	}
   147  	conn, err := NewClient(newConfig(t, "/echo"), client)
   148  	if err != nil {
   149  		t.Errorf("WebSocket handshake error: %v", err)
   150  		return
   151  	}
   152  
   153  	msg := []byte("hello, world\n")
   154  	if _, err := conn.Write(msg); err != nil {
   155  		t.Errorf("Write: %v", err)
   156  	}
   157  	var actual_msg = make([]byte, 512)
   158  	n, err := conn.Read(actual_msg)
   159  	if err != nil {
   160  		t.Errorf("Read: %v", err)
   161  	}
   162  	actual_msg = actual_msg[0:n]
   163  	if !bytes.Equal(msg, actual_msg) {
   164  		t.Errorf("Echo: expected %q got %q", msg, actual_msg)
   165  	}
   166  	conn.Close()
   167  }
   168  
   169  func TestAddr(t *testing.T) {
   170  	once.Do(startServer)
   171  
   172  	// websocket.Dial()
   173  	client, err := net.Dial("tcp", serverAddr)
   174  	if err != nil {
   175  		t.Fatal("dialing", err)
   176  	}
   177  	conn, err := NewClient(newConfig(t, "/echo"), client)
   178  	if err != nil {
   179  		t.Errorf("WebSocket handshake error: %v", err)
   180  		return
   181  	}
   182  
   183  	ra := conn.RemoteAddr().String()
   184  	if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
   185  		t.Errorf("Bad remote addr: %v", ra)
   186  	}
   187  	la := conn.LocalAddr().String()
   188  	if !strings.HasPrefix(la, "http://") {
   189  		t.Errorf("Bad local addr: %v", la)
   190  	}
   191  	conn.Close()
   192  }
   193  
   194  func TestCount(t *testing.T) {
   195  	once.Do(startServer)
   196  
   197  	// websocket.Dial()
   198  	client, err := net.Dial("tcp", serverAddr)
   199  	if err != nil {
   200  		t.Fatal("dialing", err)
   201  	}
   202  	conn, err := NewClient(newConfig(t, "/count"), client)
   203  	if err != nil {
   204  		t.Errorf("WebSocket handshake error: %v", err)
   205  		return
   206  	}
   207  
   208  	var count Count
   209  	count.S = "hello"
   210  	if err := JSON.Send(conn, count); err != nil {
   211  		t.Errorf("Write: %v", err)
   212  	}
   213  	if err := JSON.Receive(conn, &count); err != nil {
   214  		t.Errorf("Read: %v", err)
   215  	}
   216  	if count.N != 1 {
   217  		t.Errorf("count: expected %d got %d", 1, count.N)
   218  	}
   219  	if count.S != "hello" {
   220  		t.Errorf("count: expected %q got %q", "hello", count.S)
   221  	}
   222  	if err := JSON.Send(conn, count); err != nil {
   223  		t.Errorf("Write: %v", err)
   224  	}
   225  	if err := JSON.Receive(conn, &count); err != nil {
   226  		t.Errorf("Read: %v", err)
   227  	}
   228  	if count.N != 2 {
   229  		t.Errorf("count: expected %d got %d", 2, count.N)
   230  	}
   231  	if count.S != "hellohello" {
   232  		t.Errorf("count: expected %q got %q", "hellohello", count.S)
   233  	}
   234  	conn.Close()
   235  }
   236  
   237  func TestWithQuery(t *testing.T) {
   238  	once.Do(startServer)
   239  
   240  	client, err := net.Dial("tcp", serverAddr)
   241  	if err != nil {
   242  		t.Fatal("dialing", err)
   243  	}
   244  
   245  	config := newConfig(t, "/echo")
   246  	config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
   247  	if err != nil {
   248  		t.Fatal("location url", err)
   249  	}
   250  
   251  	ws, err := NewClient(config, client)
   252  	if err != nil {
   253  		t.Errorf("WebSocket handshake: %v", err)
   254  		return
   255  	}
   256  	ws.Close()
   257  }
   258  
   259  func testWithProtocol(t *testing.T, subproto []string) (string, error) {
   260  	once.Do(startServer)
   261  
   262  	client, err := net.Dial("tcp", serverAddr)
   263  	if err != nil {
   264  		t.Fatal("dialing", err)
   265  	}
   266  
   267  	config := newConfig(t, "/subproto")
   268  	config.Protocol = subproto
   269  
   270  	ws, err := NewClient(config, client)
   271  	if err != nil {
   272  		return "", err
   273  	}
   274  	msg := make([]byte, 16)
   275  	n, err := ws.Read(msg)
   276  	if err != nil {
   277  		return "", err
   278  	}
   279  	ws.Close()
   280  	return string(msg[:n]), nil
   281  }
   282  
   283  func TestWithProtocol(t *testing.T) {
   284  	proto, err := testWithProtocol(t, []string{"chat"})
   285  	if err != nil {
   286  		t.Errorf("SubProto: unexpected error: %v", err)
   287  	}
   288  	if proto != "chat" {
   289  		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
   290  	}
   291  }
   292  
   293  func TestWithTwoProtocol(t *testing.T) {
   294  	proto, err := testWithProtocol(t, []string{"test", "chat"})
   295  	if err != nil {
   296  		t.Errorf("SubProto: unexpected error: %v", err)
   297  	}
   298  	if proto != "chat" {
   299  		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
   300  	}
   301  }
   302  
   303  func TestWithBadProtocol(t *testing.T) {
   304  	_, err := testWithProtocol(t, []string{"test"})
   305  	if err != ErrBadStatus {
   306  		t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
   307  	}
   308  }
   309  
   310  func TestHTTP(t *testing.T) {
   311  	once.Do(startServer)
   312  
   313  	// If the client did not send a handshake that matches the protocol
   314  	// specification, the server MUST return an HTTP response with an
   315  	// appropriate error code (such as 400 Bad Request)
   316  	resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
   317  	if err != nil {
   318  		t.Errorf("Get: error %#v", err)
   319  		return
   320  	}
   321  	if resp == nil {
   322  		t.Error("Get: resp is null")
   323  		return
   324  	}
   325  	if resp.StatusCode != http.StatusBadRequest {
   326  		t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
   327  	}
   328  }
   329  
   330  func TestTrailingSpaces(t *testing.T) {
   331  	// http://code.google.com/p/go/issues/detail?id=955
   332  	// The last runs of this create keys with trailing spaces that should not be
   333  	// generated by the client.
   334  	once.Do(startServer)
   335  	config := newConfig(t, "/echo")
   336  	for i := 0; i < 30; i++ {
   337  		// body
   338  		ws, err := DialConfig(config)
   339  		if err != nil {
   340  			t.Errorf("Dial #%d failed: %v", i, err)
   341  			break
   342  		}
   343  		ws.Close()
   344  	}
   345  }
   346  
   347  func TestDialConfigBadVersion(t *testing.T) {
   348  	once.Do(startServer)
   349  	config := newConfig(t, "/echo")
   350  	config.Version = 1234
   351  
   352  	_, err := DialConfig(config)
   353  
   354  	if dialerr, ok := err.(*DialError); ok {
   355  		if dialerr.Err != ErrBadProtocolVersion {
   356  			t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
   357  		}
   358  	}
   359  }
   360  
   361  func TestDialConfigWithDialer(t *testing.T) {
   362  	once.Do(startServer)
   363  	config := newConfig(t, "/echo")
   364  	config.Dialer = &net.Dialer{
   365  		Deadline: time.Now().Add(-time.Minute),
   366  	}
   367  	_, err := DialConfig(config)
   368  	dialerr, ok := err.(*DialError)
   369  	if !ok {
   370  		t.Fatalf("DialError expected, got %#v", err)
   371  	}
   372  	neterr, ok := dialerr.Err.(*net.OpError)
   373  	if !ok {
   374  		t.Fatalf("net.OpError error expected, got %#v", dialerr.Err)
   375  	}
   376  	if !neterr.Timeout() {
   377  		t.Fatalf("expected timeout error, got %#v", neterr)
   378  	}
   379  }
   380  
   381  func TestSmallBuffer(t *testing.T) {
   382  	// http://code.google.com/p/go/issues/detail?id=1145
   383  	// Read should be able to handle reading a fragment of a frame.
   384  	once.Do(startServer)
   385  
   386  	// websocket.Dial()
   387  	client, err := net.Dial("tcp", serverAddr)
   388  	if err != nil {
   389  		t.Fatal("dialing", err)
   390  	}
   391  	conn, err := NewClient(newConfig(t, "/echo"), client)
   392  	if err != nil {
   393  		t.Errorf("WebSocket handshake error: %v", err)
   394  		return
   395  	}
   396  
   397  	msg := []byte("hello, world\n")
   398  	if _, err := conn.Write(msg); err != nil {
   399  		t.Errorf("Write: %v", err)
   400  	}
   401  	var small_msg = make([]byte, 8)
   402  	n, err := conn.Read(small_msg)
   403  	if err != nil {
   404  		t.Errorf("Read: %v", err)
   405  	}
   406  	if !bytes.Equal(msg[:len(small_msg)], small_msg) {
   407  		t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
   408  	}
   409  	var second_msg = make([]byte, len(msg))
   410  	n, err = conn.Read(second_msg)
   411  	if err != nil {
   412  		t.Errorf("Read: %v", err)
   413  	}
   414  	second_msg = second_msg[0:n]
   415  	if !bytes.Equal(msg[len(small_msg):], second_msg) {
   416  		t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
   417  	}
   418  	conn.Close()
   419  }
   420  
// parseAuthorityTests lists input URLs together with the authority string
// that TestParseAuthority expects parseAuthority to return for each.
var parseAuthorityTests = []struct {
	in  *url.URL // URL passed to parseAuthority
	out string   // expected result
}{
	// ws without an explicit port: port 80 is expected to be appended.
	{
		&url.URL{
			Scheme: "ws",
			Host:   "www.google.com",
		},
		"www.google.com:80",
	},
	// wss without an explicit port: port 443 is expected to be appended.
	{
		&url.URL{
			Scheme: "wss",
			Host:   "www.google.com",
		},
		"www.google.com:443",
	},
	// An explicit port is expected to be preserved as-is.
	{
		&url.URL{
			Scheme: "ws",
			Host:   "www.google.com:80",
		},
		"www.google.com:80",
	},
	{
		&url.URL{
			Scheme: "wss",
			Host:   "www.google.com:443",
		},
		"www.google.com:443",
	},
	// some invalid ones for parseAuthority. parseAuthority doesn't
	// concern itself with the scheme unless it actually knows about it
	{
		&url.URL{
			Scheme: "http",
			Host:   "www.google.com",
		},
		"www.google.com",
	},
	{
		&url.URL{
			Scheme: "http",
			Host:   "www.google.com:80",
		},
		"www.google.com:80",
	},
	{
		&url.URL{
			Scheme: "asdf",
			Host:   "127.0.0.1",
		},
		"127.0.0.1",
	},
	{
		&url.URL{
			Scheme: "asdf",
			Host:   "www.google.com",
		},
		"www.google.com",
	},
}
   484  
   485  func TestParseAuthority(t *testing.T) {
   486  	for _, tt := range parseAuthorityTests {
   487  		out := parseAuthority(tt.in)
   488  		if out != tt.out {
   489  			t.Errorf("got %v; want %v", out, tt.out)
   490  		}
   491  	}
   492  }
   493  
   494  type closerConn struct {
   495  	net.Conn
   496  	closed int // count of the number of times Close was called
   497  }
   498  
   499  func (c *closerConn) Close() error {
   500  	c.closed++
   501  	return c.Conn.Close()
   502  }
   503  
   504  func TestClose(t *testing.T) {
   505  	if runtime.GOOS == "plan9" {
   506  		t.Skip("see golang.org/issue/11454")
   507  	}
   508  
   509  	once.Do(startServer)
   510  
   511  	conn, err := net.Dial("tcp", serverAddr)
   512  	if err != nil {
   513  		t.Fatal("dialing", err)
   514  	}
   515  
   516  	cc := closerConn{Conn: conn}
   517  
   518  	client, err := NewClient(newConfig(t, "/echo"), &cc)
   519  	if err != nil {
   520  		t.Fatalf("WebSocket handshake: %v", err)
   521  	}
   522  
   523  	// set the deadline to ten minutes ago, which will have expired by the time
   524  	// client.Close sends the close status frame.
   525  	conn.SetDeadline(time.Now().Add(-10 * time.Minute))
   526  
   527  	if err := client.Close(); err == nil {
   528  		t.Errorf("ws.Close(): expected error, got %v", err)
   529  	}
   530  	if cc.closed < 1 {
   531  		t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
   532  	}
   533  }
   534  
// originTests pairs an HTTP request with the origin URL that TestOrigin
// expects Origin to return for it.
var originTests = []struct {
	req    *http.Request // request handed to Origin
	origin *url.URL      // expected result; nil when left unset
}{
	// Request carrying an Origin header.
	{
		req: &http.Request{
			Header: http.Header{
				"Origin": []string{"http://www.example.com"},
			},
		},
		origin: &url.URL{
			Scheme: "http",
			Host:   "www.example.com",
		},
	},
	// Request without any headers; the expected origin is nil.
	{
		req: &http.Request{},
	},
}
   554  
   555  func TestOrigin(t *testing.T) {
   556  	conf := newConfig(t, "/echo")
   557  	conf.Version = ProtocolVersionHybi13
   558  	for i, tt := range originTests {
   559  		origin, err := Origin(conf, tt.req)
   560  		if err != nil {
   561  			t.Error(err)
   562  			continue
   563  		}
   564  		if !reflect.DeepEqual(origin, tt.origin) {
   565  			t.Errorf("#%d: got origin %v; want %v", i, origin, tt.origin)
   566  			continue
   567  		}
   568  	}
   569  }
   570  
   571  func TestCtrlAndData(t *testing.T) {
   572  	once.Do(startServer)
   573  
   574  	c, err := net.Dial("tcp", serverAddr)
   575  	if err != nil {
   576  		t.Fatal(err)
   577  	}
   578  	ws, err := NewClient(newConfig(t, "/ctrldata"), c)
   579  	if err != nil {
   580  		t.Fatal(err)
   581  	}
   582  	defer ws.Close()
   583  
   584  	h := &testCtrlAndDataHandler{hybiFrameHandler: hybiFrameHandler{conn: ws}}
   585  	ws.frameHandler = h
   586  
   587  	b := make([]byte, 128)
   588  	for i := 0; i < 2; i++ {
   589  		data := []byte(fmt.Sprintf("#%d-DATA-FRAME-FROM-CLIENT", i))
   590  		if _, err := ws.Write(data); err != nil {
   591  			t.Fatalf("#%d: %v", i, err)
   592  		}
   593  		var ctrl []byte
   594  		if i%2 != 0 { // with or without payload
   595  			ctrl = []byte(fmt.Sprintf("#%d-CONTROL-FRAME-FROM-CLIENT", i))
   596  		}
   597  		if _, err := h.WritePing(ctrl); err != nil {
   598  			t.Fatalf("#%d: %v", i, err)
   599  		}
   600  		n, err := ws.Read(b)
   601  		if err != nil {
   602  			t.Fatalf("#%d: %v", i, err)
   603  		}
   604  		if !bytes.Equal(b[:n], data) {
   605  			t.Fatalf("#%d: got %v; want %v", i, b[:n], data)
   606  		}
   607  	}
   608  }
   609  
   610  func TestCodec_ReceiveLimited(t *testing.T) {
   611  	const limit = 2048
   612  	var payloads [][]byte
   613  	for _, size := range []int{
   614  		1024,
   615  		2048,
   616  		4096, // receive of this message would be interrupted due to limit
   617  		2048, // this one is to make sure next receive recovers discarding leftovers
   618  	} {
   619  		b := make([]byte, size)
   620  		rand.Read(b)
   621  		payloads = append(payloads, b)
   622  	}
   623  	handlerDone := make(chan struct{})
   624  	limitedHandler := func(ws *Conn) {
   625  		defer close(handlerDone)
   626  		ws.MaxPayloadBytes = limit
   627  		defer ws.Close()
   628  		for i, p := range payloads {
   629  			t.Logf("payload #%d (size %d, exceeds limit: %v)", i, len(p), len(p) > limit)
   630  			var recv []byte
   631  			err := Message.Receive(ws, &recv)
   632  			switch err {
   633  			case nil:
   634  			case ErrFrameTooLarge:
   635  				if len(p) <= limit {
   636  					t.Fatalf("unexpected frame size limit: expected %d bytes of payload having limit at %d", len(p), limit)
   637  				}
   638  				continue
   639  			default:
   640  				t.Fatalf("unexpected error: %v (want either nil or ErrFrameTooLarge)", err)
   641  			}
   642  			if len(recv) > limit {
   643  				t.Fatalf("received %d bytes of payload having limit at %d", len(recv), limit)
   644  			}
   645  			if !bytes.Equal(p, recv) {
   646  				t.Fatalf("received payload differs:\ngot:\t%v\nwant:\t%v", recv, p)
   647  			}
   648  		}
   649  	}
   650  	server := httptest.NewServer(Handler(limitedHandler))
   651  	defer server.CloseClientConnections()
   652  	defer server.Close()
   653  	addr := server.Listener.Addr().String()
   654  	ws, err := Dial("ws://"+addr+"/", "", "http://localhost/")
   655  	if err != nil {
   656  		t.Fatal(err)
   657  	}
   658  	defer ws.Close()
   659  	for i, p := range payloads {
   660  		if err := Message.Send(ws, p); err != nil {
   661  			t.Fatalf("payload #%d (size %d): %v", i, len(p), err)
   662  		}
   663  	}
   664  	<-handlerDone
   665  }