github.com/tmlbl/deis@v1.0.2/logspout/Godeps/_workspace/src/code.google.com/p/go.net/websocket/websocket_test.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bytes"
     9  	"fmt"
    10  	"io"
    11  	"log"
    12  	"net"
    13  	"net/http"
    14  	"net/http/httptest"
    15  	"net/url"
    16  	"strings"
    17  	"sync"
    18  	"testing"
    19  )
    20  
    21  var serverAddr string
    22  var once sync.Once
    23  
    24  func echoServer(ws *Conn) { io.Copy(ws, ws) }
    25  
    26  type Count struct {
    27  	S string
    28  	N int
    29  }
    30  
    31  func countServer(ws *Conn) {
    32  	for {
    33  		var count Count
    34  		err := JSON.Receive(ws, &count)
    35  		if err != nil {
    36  			return
    37  		}
    38  		count.N++
    39  		count.S = strings.Repeat(count.S, count.N)
    40  		err = JSON.Send(ws, count)
    41  		if err != nil {
    42  			return
    43  		}
    44  	}
    45  }
    46  
    47  func subProtocolHandshake(config *Config, req *http.Request) error {
    48  	for _, proto := range config.Protocol {
    49  		if proto == "chat" {
    50  			config.Protocol = []string{proto}
    51  			return nil
    52  		}
    53  	}
    54  	return ErrBadWebSocketProtocol
    55  }
    56  
    57  func subProtoServer(ws *Conn) {
    58  	for _, proto := range ws.Config().Protocol {
    59  		io.WriteString(ws, proto)
    60  	}
    61  }
    62  
    63  func startServer() {
    64  	http.Handle("/echo", Handler(echoServer))
    65  	http.Handle("/count", Handler(countServer))
    66  	subproto := Server{
    67  		Handshake: subProtocolHandshake,
    68  		Handler:   Handler(subProtoServer),
    69  	}
    70  	http.Handle("/subproto", subproto)
    71  	server := httptest.NewServer(nil)
    72  	serverAddr = server.Listener.Addr().String()
    73  	log.Print("Test WebSocket server listening on ", serverAddr)
    74  }
    75  
    76  func newConfig(t *testing.T, path string) *Config {
    77  	config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost")
    78  	return config
    79  }
    80  
    81  func TestEcho(t *testing.T) {
    82  	once.Do(startServer)
    83  
    84  	// websocket.Dial()
    85  	client, err := net.Dial("tcp", serverAddr)
    86  	if err != nil {
    87  		t.Fatal("dialing", err)
    88  	}
    89  	conn, err := NewClient(newConfig(t, "/echo"), client)
    90  	if err != nil {
    91  		t.Errorf("WebSocket handshake error: %v", err)
    92  		return
    93  	}
    94  
    95  	msg := []byte("hello, world\n")
    96  	if _, err := conn.Write(msg); err != nil {
    97  		t.Errorf("Write: %v", err)
    98  	}
    99  	var actual_msg = make([]byte, 512)
   100  	n, err := conn.Read(actual_msg)
   101  	if err != nil {
   102  		t.Errorf("Read: %v", err)
   103  	}
   104  	actual_msg = actual_msg[0:n]
   105  	if !bytes.Equal(msg, actual_msg) {
   106  		t.Errorf("Echo: expected %q got %q", msg, actual_msg)
   107  	}
   108  	conn.Close()
   109  }
   110  
   111  func TestAddr(t *testing.T) {
   112  	once.Do(startServer)
   113  
   114  	// websocket.Dial()
   115  	client, err := net.Dial("tcp", serverAddr)
   116  	if err != nil {
   117  		t.Fatal("dialing", err)
   118  	}
   119  	conn, err := NewClient(newConfig(t, "/echo"), client)
   120  	if err != nil {
   121  		t.Errorf("WebSocket handshake error: %v", err)
   122  		return
   123  	}
   124  
   125  	ra := conn.RemoteAddr().String()
   126  	if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") {
   127  		t.Errorf("Bad remote addr: %v", ra)
   128  	}
   129  	la := conn.LocalAddr().String()
   130  	if !strings.HasPrefix(la, "http://") {
   131  		t.Errorf("Bad local addr: %v", la)
   132  	}
   133  	conn.Close()
   134  }
   135  
   136  func TestCount(t *testing.T) {
   137  	once.Do(startServer)
   138  
   139  	// websocket.Dial()
   140  	client, err := net.Dial("tcp", serverAddr)
   141  	if err != nil {
   142  		t.Fatal("dialing", err)
   143  	}
   144  	conn, err := NewClient(newConfig(t, "/count"), client)
   145  	if err != nil {
   146  		t.Errorf("WebSocket handshake error: %v", err)
   147  		return
   148  	}
   149  
   150  	var count Count
   151  	count.S = "hello"
   152  	if err := JSON.Send(conn, count); err != nil {
   153  		t.Errorf("Write: %v", err)
   154  	}
   155  	if err := JSON.Receive(conn, &count); err != nil {
   156  		t.Errorf("Read: %v", err)
   157  	}
   158  	if count.N != 1 {
   159  		t.Errorf("count: expected %d got %d", 1, count.N)
   160  	}
   161  	if count.S != "hello" {
   162  		t.Errorf("count: expected %q got %q", "hello", count.S)
   163  	}
   164  	if err := JSON.Send(conn, count); err != nil {
   165  		t.Errorf("Write: %v", err)
   166  	}
   167  	if err := JSON.Receive(conn, &count); err != nil {
   168  		t.Errorf("Read: %v", err)
   169  	}
   170  	if count.N != 2 {
   171  		t.Errorf("count: expected %d got %d", 2, count.N)
   172  	}
   173  	if count.S != "hellohello" {
   174  		t.Errorf("count: expected %q got %q", "hellohello", count.S)
   175  	}
   176  	conn.Close()
   177  }
   178  
   179  func TestWithQuery(t *testing.T) {
   180  	once.Do(startServer)
   181  
   182  	client, err := net.Dial("tcp", serverAddr)
   183  	if err != nil {
   184  		t.Fatal("dialing", err)
   185  	}
   186  
   187  	config := newConfig(t, "/echo")
   188  	config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr))
   189  	if err != nil {
   190  		t.Fatal("location url", err)
   191  	}
   192  
   193  	ws, err := NewClient(config, client)
   194  	if err != nil {
   195  		t.Errorf("WebSocket handshake: %v", err)
   196  		return
   197  	}
   198  	ws.Close()
   199  }
   200  
   201  func testWithProtocol(t *testing.T, subproto []string) (string, error) {
   202  	once.Do(startServer)
   203  
   204  	client, err := net.Dial("tcp", serverAddr)
   205  	if err != nil {
   206  		t.Fatal("dialing", err)
   207  	}
   208  
   209  	config := newConfig(t, "/subproto")
   210  	config.Protocol = subproto
   211  
   212  	ws, err := NewClient(config, client)
   213  	if err != nil {
   214  		return "", err
   215  	}
   216  	msg := make([]byte, 16)
   217  	n, err := ws.Read(msg)
   218  	if err != nil {
   219  		return "", err
   220  	}
   221  	ws.Close()
   222  	return string(msg[:n]), nil
   223  }
   224  
   225  func TestWithProtocol(t *testing.T) {
   226  	proto, err := testWithProtocol(t, []string{"chat"})
   227  	if err != nil {
   228  		t.Errorf("SubProto: unexpected error: %v", err)
   229  	}
   230  	if proto != "chat" {
   231  		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
   232  	}
   233  }
   234  
   235  func TestWithTwoProtocol(t *testing.T) {
   236  	proto, err := testWithProtocol(t, []string{"test", "chat"})
   237  	if err != nil {
   238  		t.Errorf("SubProto: unexpected error: %v", err)
   239  	}
   240  	if proto != "chat" {
   241  		t.Errorf("SubProto: expected %q, got %q", "chat", proto)
   242  	}
   243  }
   244  
   245  func TestWithBadProtocol(t *testing.T) {
   246  	_, err := testWithProtocol(t, []string{"test"})
   247  	if err != ErrBadStatus {
   248  		t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err)
   249  	}
   250  }
   251  
   252  func TestHTTP(t *testing.T) {
   253  	once.Do(startServer)
   254  
   255  	// If the client did not send a handshake that matches the protocol
   256  	// specification, the server MUST return an HTTP response with an
   257  	// appropriate error code (such as 400 Bad Request)
   258  	resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr))
   259  	if err != nil {
   260  		t.Errorf("Get: error %#v", err)
   261  		return
   262  	}
   263  	if resp == nil {
   264  		t.Error("Get: resp is null")
   265  		return
   266  	}
   267  	if resp.StatusCode != http.StatusBadRequest {
   268  		t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode)
   269  	}
   270  }
   271  
   272  func TestTrailingSpaces(t *testing.T) {
   273  	// http://code.google.com/p/go/issues/detail?id=955
   274  	// The last runs of this create keys with trailing spaces that should not be
   275  	// generated by the client.
   276  	once.Do(startServer)
   277  	config := newConfig(t, "/echo")
   278  	for i := 0; i < 30; i++ {
   279  		// body
   280  		ws, err := DialConfig(config)
   281  		if err != nil {
   282  			t.Errorf("Dial #%d failed: %v", i, err)
   283  			break
   284  		}
   285  		ws.Close()
   286  	}
   287  }
   288  
   289  func TestDialConfigBadVersion(t *testing.T) {
   290  	once.Do(startServer)
   291  	config := newConfig(t, "/echo")
   292  	config.Version = 1234
   293  
   294  	_, err := DialConfig(config)
   295  
   296  	if dialerr, ok := err.(*DialError); ok {
   297  		if dialerr.Err != ErrBadProtocolVersion {
   298  			t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err)
   299  		}
   300  	}
   301  }
   302  
   303  func TestSmallBuffer(t *testing.T) {
   304  	// http://code.google.com/p/go/issues/detail?id=1145
   305  	// Read should be able to handle reading a fragment of a frame.
   306  	once.Do(startServer)
   307  
   308  	// websocket.Dial()
   309  	client, err := net.Dial("tcp", serverAddr)
   310  	if err != nil {
   311  		t.Fatal("dialing", err)
   312  	}
   313  	conn, err := NewClient(newConfig(t, "/echo"), client)
   314  	if err != nil {
   315  		t.Errorf("WebSocket handshake error: %v", err)
   316  		return
   317  	}
   318  
   319  	msg := []byte("hello, world\n")
   320  	if _, err := conn.Write(msg); err != nil {
   321  		t.Errorf("Write: %v", err)
   322  	}
   323  	var small_msg = make([]byte, 8)
   324  	n, err := conn.Read(small_msg)
   325  	if err != nil {
   326  		t.Errorf("Read: %v", err)
   327  	}
   328  	if !bytes.Equal(msg[:len(small_msg)], small_msg) {
   329  		t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg)
   330  	}
   331  	var second_msg = make([]byte, len(msg))
   332  	n, err = conn.Read(second_msg)
   333  	if err != nil {
   334  		t.Errorf("Read: %v", err)
   335  	}
   336  	second_msg = second_msg[0:n]
   337  	if !bytes.Equal(msg[len(small_msg):], second_msg) {
   338  		t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg)
   339  	}
   340  	conn.Close()
   341  }