github.com/tmlbl/deis@v1.0.2/logspout/Godeps/_workspace/src/code.google.com/p/go.net/websocket/websocket_test.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 import ( 8 "bytes" 9 "fmt" 10 "io" 11 "log" 12 "net" 13 "net/http" 14 "net/http/httptest" 15 "net/url" 16 "strings" 17 "sync" 18 "testing" 19 ) 20 21 var serverAddr string 22 var once sync.Once 23 24 func echoServer(ws *Conn) { io.Copy(ws, ws) } 25 26 type Count struct { 27 S string 28 N int 29 } 30 31 func countServer(ws *Conn) { 32 for { 33 var count Count 34 err := JSON.Receive(ws, &count) 35 if err != nil { 36 return 37 } 38 count.N++ 39 count.S = strings.Repeat(count.S, count.N) 40 err = JSON.Send(ws, count) 41 if err != nil { 42 return 43 } 44 } 45 } 46 47 func subProtocolHandshake(config *Config, req *http.Request) error { 48 for _, proto := range config.Protocol { 49 if proto == "chat" { 50 config.Protocol = []string{proto} 51 return nil 52 } 53 } 54 return ErrBadWebSocketProtocol 55 } 56 57 func subProtoServer(ws *Conn) { 58 for _, proto := range ws.Config().Protocol { 59 io.WriteString(ws, proto) 60 } 61 } 62 63 func startServer() { 64 http.Handle("/echo", Handler(echoServer)) 65 http.Handle("/count", Handler(countServer)) 66 subproto := Server{ 67 Handshake: subProtocolHandshake, 68 Handler: Handler(subProtoServer), 69 } 70 http.Handle("/subproto", subproto) 71 server := httptest.NewServer(nil) 72 serverAddr = server.Listener.Addr().String() 73 log.Print("Test WebSocket server listening on ", serverAddr) 74 } 75 76 func newConfig(t *testing.T, path string) *Config { 77 config, _ := NewConfig(fmt.Sprintf("ws://%s%s", serverAddr, path), "http://localhost") 78 return config 79 } 80 81 func TestEcho(t *testing.T) { 82 once.Do(startServer) 83 84 // websocket.Dial() 85 client, err := net.Dial("tcp", serverAddr) 86 if err != nil { 87 t.Fatal("dialing", err) 88 } 89 conn, err := NewClient(newConfig(t, "/echo"), client) 90 if err != nil { 91 t.Errorf("WebSocket handshake error: %v", err) 92 return 93 } 94 95 msg := []byte("hello, world\n") 96 if _, err := conn.Write(msg); err != nil { 97 t.Errorf("Write: %v", err) 98 } 99 var actual_msg = make([]byte, 512) 100 n, err := conn.Read(actual_msg) 101 if err != nil { 102 t.Errorf("Read: %v", err) 103 } 104 actual_msg = actual_msg[0:n] 105 if !bytes.Equal(msg, actual_msg) { 106 t.Errorf("Echo: expected %q got %q", msg, actual_msg) 107 } 108 conn.Close() 109 } 110 111 func TestAddr(t *testing.T) { 112 once.Do(startServer) 113 114 // websocket.Dial() 115 client, err := net.Dial("tcp", serverAddr) 116 if err != nil { 117 t.Fatal("dialing", err) 118 } 119 conn, err := NewClient(newConfig(t, "/echo"), client) 120 if err != nil { 121 t.Errorf("WebSocket handshake error: %v", err) 122 return 123 } 124 125 ra := conn.RemoteAddr().String() 126 if !strings.HasPrefix(ra, "ws://") || !strings.HasSuffix(ra, "/echo") { 127 t.Errorf("Bad remote addr: %v", ra) 128 } 129 la := conn.LocalAddr().String() 130 if !strings.HasPrefix(la, "http://") { 131 t.Errorf("Bad local addr: %v", la) 132 } 133 conn.Close() 134 } 135 136 func TestCount(t *testing.T) { 137 once.Do(startServer) 138 139 // websocket.Dial() 140 client, err := net.Dial("tcp", serverAddr) 141 if err != nil { 142 t.Fatal("dialing", err) 143 } 144 conn, err := NewClient(newConfig(t, "/count"), client) 145 if err != nil { 146 t.Errorf("WebSocket handshake error: %v", err) 147 return 148 } 149 150 var count Count 151 count.S = "hello" 152 if err := JSON.Send(conn, count); err != nil { 153 t.Errorf("Write: %v", err) 154 } 155 if err := JSON.Receive(conn, &count); err != nil { 156 t.Errorf("Read: %v", err) 157 } 158 if count.N != 1 { 159 t.Errorf("count: expected %d got %d", 1, count.N) 160 } 161 if count.S != "hello" { 162 t.Errorf("count: expected %q got %q", "hello", count.S) 163 } 164 if err := JSON.Send(conn, count); err != nil { 165 t.Errorf("Write: %v", err) 166 } 167 if err := JSON.Receive(conn, &count); err != nil { 168 t.Errorf("Read: %v", err) 169 } 170 if count.N != 2 { 171 t.Errorf("count: expected %d got %d", 2, count.N) 172 } 173 if count.S != "hellohello" { 174 t.Errorf("count: expected %q got %q", "hellohello", count.S) 175 } 176 conn.Close() 177 } 178 179 func TestWithQuery(t *testing.T) { 180 once.Do(startServer) 181 182 client, err := net.Dial("tcp", serverAddr) 183 if err != nil { 184 t.Fatal("dialing", err) 185 } 186 187 config := newConfig(t, "/echo") 188 config.Location, err = url.ParseRequestURI(fmt.Sprintf("ws://%s/echo?q=v", serverAddr)) 189 if err != nil { 190 t.Fatal("location url", err) 191 } 192 193 ws, err := NewClient(config, client) 194 if err != nil { 195 t.Errorf("WebSocket handshake: %v", err) 196 return 197 } 198 ws.Close() 199 } 200 201 func testWithProtocol(t *testing.T, subproto []string) (string, error) { 202 once.Do(startServer) 203 204 client, err := net.Dial("tcp", serverAddr) 205 if err != nil { 206 t.Fatal("dialing", err) 207 } 208 209 config := newConfig(t, "/subproto") 210 config.Protocol = subproto 211 212 ws, err := NewClient(config, client) 213 if err != nil { 214 return "", err 215 } 216 msg := make([]byte, 16) 217 n, err := ws.Read(msg) 218 if err != nil { 219 return "", err 220 } 221 ws.Close() 222 return string(msg[:n]), nil 223 } 224 225 func TestWithProtocol(t *testing.T) { 226 proto, err := testWithProtocol(t, []string{"chat"}) 227 if err != nil { 228 t.Errorf("SubProto: unexpected error: %v", err) 229 } 230 if proto != "chat" { 231 t.Errorf("SubProto: expected %q, got %q", "chat", proto) 232 } 233 } 234 235 func TestWithTwoProtocol(t *testing.T) { 236 proto, err := testWithProtocol(t, []string{"test", "chat"}) 237 if err != nil { 238 t.Errorf("SubProto: unexpected error: %v", err) 239 } 240 if proto != "chat" { 241 t.Errorf("SubProto: expected %q, got %q", "chat", proto) 242 } 243 } 244 245 func TestWithBadProtocol(t *testing.T) { 246 _, err := testWithProtocol(t, []string{"test"}) 247 if err != ErrBadStatus { 248 t.Errorf("SubProto: expected %v, got %v", ErrBadStatus, err) 249 } 250 } 251 252 func TestHTTP(t *testing.T) { 253 once.Do(startServer) 254 255 // If the client did not send a handshake that matches the protocol 256 // specification, the server MUST return an HTTP response with an 257 // appropriate error code (such as 400 Bad Request) 258 resp, err := http.Get(fmt.Sprintf("http://%s/echo", serverAddr)) 259 if err != nil { 260 t.Errorf("Get: error %#v", err) 261 return 262 } 263 if resp == nil { 264 t.Error("Get: resp is null") 265 return 266 } 267 if resp.StatusCode != http.StatusBadRequest { 268 t.Errorf("Get: expected %q got %q", http.StatusBadRequest, resp.StatusCode) 269 } 270 } 271 272 func TestTrailingSpaces(t *testing.T) { 273 // http://code.google.com/p/go/issues/detail?id=955 274 // The last runs of this create keys with trailing spaces that should not be 275 // generated by the client. 276 once.Do(startServer) 277 config := newConfig(t, "/echo") 278 for i := 0; i < 30; i++ { 279 // body 280 ws, err := DialConfig(config) 281 if err != nil { 282 t.Errorf("Dial #%d failed: %v", i, err) 283 break 284 } 285 ws.Close() 286 } 287 } 288 289 func TestDialConfigBadVersion(t *testing.T) { 290 once.Do(startServer) 291 config := newConfig(t, "/echo") 292 config.Version = 1234 293 294 _, err := DialConfig(config) 295 296 if dialerr, ok := err.(*DialError); ok { 297 if dialerr.Err != ErrBadProtocolVersion { 298 t.Errorf("dial expected err %q but got %q", ErrBadProtocolVersion, dialerr.Err) 299 } 300 } 301 } 302 303 func TestSmallBuffer(t *testing.T) { 304 // http://code.google.com/p/go/issues/detail?id=1145 305 // Read should be able to handle reading a fragment of a frame. 306 once.Do(startServer) 307 308 // websocket.Dial() 309 client, err := net.Dial("tcp", serverAddr) 310 if err != nil { 311 t.Fatal("dialing", err) 312 } 313 conn, err := NewClient(newConfig(t, "/echo"), client) 314 if err != nil { 315 t.Errorf("WebSocket handshake error: %v", err) 316 return 317 } 318 319 msg := []byte("hello, world\n") 320 if _, err := conn.Write(msg); err != nil { 321 t.Errorf("Write: %v", err) 322 } 323 var small_msg = make([]byte, 8) 324 n, err := conn.Read(small_msg) 325 if err != nil { 326 t.Errorf("Read: %v", err) 327 } 328 if !bytes.Equal(msg[:len(small_msg)], small_msg) { 329 t.Errorf("Echo: expected %q got %q", msg[:len(small_msg)], small_msg) 330 } 331 var second_msg = make([]byte, len(msg)) 332 n, err = conn.Read(second_msg) 333 if err != nil { 334 t.Errorf("Read: %v", err) 335 } 336 second_msg = second_msg[0:n] 337 if !bytes.Equal(msg[len(small_msg):], second_msg) { 338 t.Errorf("Echo: expected %q got %q", msg[len(small_msg):], second_msg) 339 } 340 conn.Close() 341 }