github.com/deis/deis@v1.13.5-0.20170519182049-1d9e59fbdbfc/Godeps/_workspace/src/golang.org/x/net/websocket/websocket_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"net/url"
    16  	"strings"
    17  	"sync"
    18  	"testing"
    19  	"time"
    20  )
    21  
    22  var serverAddr string
    23  var once sync.Once
    24  
    25  func echoServer(ws *Conn) { io.Copy(ws, ws) }
    26  
    27  type Count struct {
    28  	S string
    29  	N int
    30  }
    31  
    32  func countServer(ws *Conn) {
    33  	for {
    34  		var count Count
    35  		err := JSON.Receive(ws, &count)
    36  		if err != nil {
    37  			return
    38  		}
    39  		count.N++
    40  		count.S = strings.Repeat(count.S, count.N)
    41  		err = JSON.Send(ws, count)
    42  		if err != nil {
    43  			return
    44  		}
    45  	}
    46  }
    47  
    48  func subProtocolHandshake(config *Config, req *http.Request) error {
    49  	for _, proto := range config.Protocol {
    50  		if proto == "chat" {
    51  			config.Protocol = []string{proto}
    52  			return nil
    53  		}
    54  	}
    55  	return ErrBadWebSocketProtocol
    56  }
    57  
    58  func subProtoServer(ws *Conn) {
    59  	for _, proto := range ws.Config().Protocol {
    60  		io.WriteString(ws, proto)
    61  	}
    62  }
    63  
    64  func startServer() {
    65  	http.Handle("/echo", Handler(echoServer))
    66  	http.Handle("/count", Handler(countServer))
    67  	subproto := Server{
    68  		Handshake: subProtocolHandshake,
    69  		Handler:   Handler(subProtoServer),
    70  	}
    71  	http.Handle("/subproto", subproto)
    72  	server := httptest.NewServer(nil)
    73  	serverAddr = server.Listener.Addr().String()
    74  	log.Print("Test WebSocket server listening on ", serverAddr)
    75  }
    76  
    77  func newConfig(t *testing.T, path string) *Config {
    78  	config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
    79  	return config
    80  }
    81  
    82  func TestEcho(t *testing.T) {
    83  	once.Do(startServer)
    84  
    85  	// websocket.Dial()
    86  	client, err := net.Dial("tcp", serverAddr)
    87  	if err != nil {
    88  		t.Fatal("dialing", err)
    89  	}
    90  	conn, err := NewClient(newConfig(t, "/echo"), client)
    91  	if err != nil {
    92  		t.Errorf("WebSocket handshake error: %v", err)
    93  		return
    94  	}
    95  
    96  	msg := []byte("hello, world\n")
    97  	if _, err := conn.Write(msg); err != nil {
    98  		t.Errorf("Write: %v", err)
    99  	}
   100  	var actual_msg = make([]byte, 512)
   101  	n, err := conn.Read(actual_msg)
   102  	if err != nil {
   103  		t.Errorf("Read: %v", err)
   104  	}
   105  	actual_msg = actual_msg[0:n]
   106  	if !bytes.Equal(msg, actual_msg) {
   107  		t.Errorf("Echo: expected %q got %q", msg, actual_msg)
   108  	}
   109  	conn.Close()
   110  }
   111  
   112  func TestAddr(t *testing.T) {
   113  	once.Do(startServer)
   114  
   115  	// websocket.Dial()
   116  	client, err := net.Dial("tcp", serverAddr)
   117  	if err != nil {
   118  		t.Fatal("dialing", err)
   119  	}
   120  	conn, err := NewClient(newConfig(t, "/echo"), client)
   121  	if err != nil {
   122  		t.Errorf("WebSocket handshake error: %v", err)
   123  		return
   124  	}
   125  
   126  	ra := conn.RemoteAddr().String()
   127  	if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
   128  		t.Errorf("Bad remote addr: %v", ra)
   129  	}
   130  	la := conn.LocalAddr().String()
   131  	if !strings.HasPrefix(la, "http://") {
   132  		t.Errorf("Bad local addr: %v", la)
   133  	}
   134  	conn.Close()
   135  }
   136  
   137  func TestCount(t *testing.T) {
   138  	once.Do(startServer)
   139  
   140  	// websocket.Dial()
   141  	client, err := net.Dial("tcp", serverAddr)
   142  	if err != nil {
   143  		t.Fatal("dialing", err)
   144  	}
   145  	conn, err := NewClient(newConfig(t, "/count"), client)
   146  	if err != nil {
   147  		t.Errorf("WebSocket handshake error: %v", err)
   148  		return
   149  	}
   150  
   151  	var count Count
   152  	count.S = "hello"
   153  	if err := JSON.Send(conn, count); err != nil {
   154  		t.Errorf("Write: %v", err)
   155  	}
   156  	if err := JSON.Receive(conn, &count); err != nil {
   157  		t.Errorf("Read: %v", err)
   158  	}
   159  	if count.N != 1 {
   160  		t.Errorf("count: expected %d got %d", 1, count.N)
   161  	}
   162  	if count.S != "hello" {
   163  		t.Errorf("count: expected %q got %q", "hello", count.S)
   164  	}
   165  	if err := JSON.Send(conn, count); err != nil {
   166  		t.Errorf("Write: %v", err)
   167  	}
   168  	if err := JSON.Receive(conn, &count); err != nil {
   169  		t.Errorf("Read: %v", err)
   170  	}
   171  	if count.N != 2 {
   172  		t.Errorf("count: expected %d got %d", 2, count.N)
   173  	}
   174  	if count.S != "hellohello" {
   175  		t.Errorf("count: expected %q got %q", "hellohello", count.S)
   176  	}
   177  	conn.Close()
   178  }
   179  
   180  func TestWithQuery(t *testing.T) {
   181  	once.Do(startServer)
   182  
   183  	client, err := net.Dial("tcp", serverAddr)
   184  	if err != nil {
   185  		t.Fatal("dialing", err)
   186  	}
   187  
   188  	config := newConfig(t, "/echo")
   189  	config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
   190  	if err != nil {
   191  		t.Fatal("location url", err)
   192  	}
   193  
   194  	ws, err := NewClient(config, client)
   195  	if err != nil {
   196  		t.Errorf("WebSocket handshake: %v", err)
   197  		return
   198  	}
   199  	ws.Close()
   200  }
   201  
   202  func testWithProtocol(t *testing.T, subproto []string) (string, error) {
   203  	once.Do(startServer)
   204  
   205  	client, err := net.Dial("tcp", serverAddr)
   206  	if err != nil {
   207  		t.Fatal("dialing", err)
   208  	}
   209  
   210  	config := newConfig(t, "/subproto")
   211  	config.Protocol = subproto
   212  
   213  	ws, err := NewClient(config, client)
   214  	if err != nil {
   215  		return "", err
   216  	}
   217  	msg := make([]byte, 16)
   218  	n, err := ws.Read(msg)
   219  	if err != nil {
   220  		return "", err
   221  	}
   222  	ws.Close()
   223  	return string(msg[:n]), nil
   224  }
   225  
   226  func TestWithProtocol(t *testing.T) {
   227  	proto, err := testWithProtocol(t, []string{"chat"})
   228  	if err != nil {
   229  		t.Errorf("SubProto: unexpected error: %v", err)
   230  	}
   231  	if proto != "chat" {
   232  		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
   233  	}
   234  }
   235  
   236  func TestWithTwoProtocol(t *testing.T) {
   237  	proto, err := testWithProtocol(t, []string{"test", "chat"})
   238  	if err != nil {
   239  		t.Errorf("SubProto: unexpected error: %v", err)
   240  	}
   241  	if proto != "chat" {
   242  		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
   243  	}
   244  }
   245  
   246  func TestWithBadProtocol(t *testing.T) {
   247  	_, err := testWithProtocol(t, []string{"test"})
   248  	if err != ErrBadStatus {
   249  		t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
   250  	}
   251  }
   252  
   253  func TestHTTP(t *testing.T) {
   254  	once.Do(startServer)
   255  
   256  	// If the client did not send a handshake that matches the protocol
   257  	// specification, the server MUST return an HTTP response with an
   258  	// appropriate error code (such as 400 Bad Request)
   259  	resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
   260  	if err != nil {
   261  		t.Errorf("Get: error %#v", err)
   262  		return
   263  	}
   264  	if resp == nil {
   265  		t.Error("Get: resp is null")
   266  		return
   267  	}
   268  	if resp.StatusCode != http.StatusBadRequest {
   269  		t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
   270  	}
   271  }
   272  
   273  func TestTrailingSpaces(t *testing.T) {
   274  	// http://code.google.com/p/go/issues/detail?id=955
   275  	// The last runs of this create keys with trailing spaces that should not be
   276  	// generated by the client.
   277  	once.Do(startServer)
   278  	config := newConfig(t, "/echo")
   279  	for i := 0; i < 30; i++ {
   280  		// body
   281  		ws, err := DialConfig(config)
   282  		if err != nil {
   283  			t.Errorf("Dial #%d failed: %v", i, err)
   284  			break
   285  		}
   286  		ws.Close()
   287  	}
   288  }
   289  
   290  func TestDialConfigBadVersion(t *testing.T) {
   291  	once.Do(startServer)
   292  	config := newConfig(t, "/echo")
   293  	config.Version = 1234
   294  
   295  	_, err := DialConfig(config)
   296  
   297  	if dialerr, ok := err.(*DialError); ok {
   298  		if dialerr.Err != ErrBadProtocolVersion {
   299  			t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
   300  		}
   301  	}
   302  }
   303  
   304  func TestSmallBuffer(t *testing.T) {
   305  	// http://code.google.com/p/go/issues/detail?id=1145
   306  	// Read should be able to handle reading a fragment of a frame.
   307  	once.Do(startServer)
   308  
   309  	// websocket.Dial()
   310  	client, err := net.Dial("tcp", serverAddr)
   311  	if err != nil {
   312  		t.Fatal("dialing", err)
   313  	}
   314  	conn, err := NewClient(newConfig(t, "/echo"), client)
   315  	if err != nil {
   316  		t.Errorf("WebSocket handshake error: %v", err)
   317  		return
   318  	}
   319  
   320  	msg := []byte("hello, world\n")
   321  	if _, err := conn.Write(msg); err != nil {
   322  		t.Errorf("Write: %v", err)
   323  	}
   324  	var small_msg = make([]byte, 8)
   325  	n, err := conn.Read(small_msg)
   326  	if err != nil {
   327  		t.Errorf("Read: %v", err)
   328  	}
   329  	if !bytes.Equal(msg[:len(small_msg)], small_msg) {
   330  		t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
   331  	}
   332  	var second_msg = make([]byte, len(msg))
   333  	n, err = conn.Read(second_msg)
   334  	if err != nil {
   335  		t.Errorf("Read: %v", err)
   336  	}
   337  	second_msg = second_msg[0:n]
   338  	if !bytes.Equal(msg[len(small_msg):], second_msg) {
   339  		t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
   340  	}
   341  	conn.Close()
   342  }
   343  
   344  var parseAuthorityTests = []struct {
   345  	in  *url.URL
   346  	out string
   347  }{
   348  	{
   349  		&url.URL{
   350  			Scheme: "ws",
   351  			Host:   "www.google.com",
   352  		},
   353  		"www.google.com:80",
   354  	},
   355  	{
   356  		&url.URL{
   357  			Scheme: "wss",
   358  			Host:   "www.google.com",
   359  		},
   360  		"www.google.com:443",
   361  	},
   362  	{
   363  		&url.URL{
   364  			Scheme: "ws",
   365  			Host:   "www.google.com:80",
   366  		},
   367  		"www.google.com:80",
   368  	},
   369  	{
   370  		&url.URL{
   371  			Scheme: "wss",
   372  			Host:   "www.google.com:443",
   373  		},
   374  		"www.google.com:443",
   375  	},
   376  	// some invalid ones for parseAuthority. parseAuthority doesn't
   377  	// concern itself with the scheme unless it actually knows about it
   378  	{
   379  		&url.URL{
   380  			Scheme: "http",
   381  			Host:   "www.google.com",
   382  		},
   383  		"www.google.com",
   384  	},
   385  	{
   386  		&url.URL{
   387  			Scheme: "http",
   388  			Host:   "www.google.com:80",
   389  		},
   390  		"www.google.com:80",
   391  	},
   392  	{
   393  		&url.URL{
   394  			Scheme: "asdf",
   395  			Host:   "127.0.0.1",
   396  		},
   397  		"127.0.0.1",
   398  	},
   399  	{
   400  		&url.URL{
   401  			Scheme: "asdf",
   402  			Host:   "www.google.com",
   403  		},
   404  		"www.google.com",
   405  	},
   406  }
   407  
   408  func TestParseAuthority(t *testing.T) {
   409  	for _, tt := range parseAuthorityTests {
   410  		out := parseAuthority(tt.in)
   411  		if out != tt.out {
   412  			t.Errorf("got %v; want %v", out, tt.out)
   413  		}
   414  	}
   415  }
   416  
   417  type closerConn struct {
   418  	net.Conn
   419  	closed int // count of the number of times Close was called
   420  }
   421  
   422  func (c *closerConn) Close() error {
   423  	c.closed++
   424  	return c.Conn.Close()
   425  }
   426  
   427  func TestClose(t *testing.T) {
   428  	once.Do(startServer)
   429  
   430  	conn, err := net.Dial("tcp", serverAddr)
   431  	if err != nil {
   432  		t.Fatal("dialing", err)
   433  	}
   434  
   435  	cc := closerConn{Conn: conn}
   436  
   437  	client, err := NewClient(newConfig(t, "/echo"), &cc)
   438  	if err != nil {
   439  		t.Fatalf("WebSocket handshake: %v", err)
   440  	}
   441  
   442  	// set the deadline to ten minutes ago, which will have expired by the time
   443  	// client.Close sends the close status frame.
   444  	conn.SetDeadline(time.Now().Add(-10 * time.Minute))
   445  
   446  	if err := client.Close(); err == nil {
   447  		t.Errorf("ws.Close(): expected error, got %v", err)
   448  	}
   449  	if cc.closed < 1 {
   450  		t.Fatalf("ws.Close(): expected underlying ws.rwc.Close to be called > 0 times, got: %v", cc.closed)
   451  	}
   452  }