github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/gorilla/websocket/client_server_test.go (about) 1 // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 import ( 8 "crypto/tls" 9 "crypto/x509" 10 "encoding/base64" 11 "io" 12 "io/ioutil" 13 "net" 14 "net/http" 15 "net/http/httptest" 16 "net/url" 17 "reflect" 18 "strings" 19 "testing" 20 "time" 21 ) 22 23 var cstUpgrader = Upgrader{ 24 Subprotocols: []string{"p0", "p1"}, 25 ReadBufferSize: 1024, 26 WriteBufferSize: 1024, 27 Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) { 28 http.Error(w, reason.Error(), status) 29 }, 30 } 31 32 var cstDialer = Dialer{ 33 Subprotocols: []string{"p1", "p2"}, 34 ReadBufferSize: 1024, 35 WriteBufferSize: 1024, 36 } 37 38 type cstHandler struct{ *testing.T } 39 40 type cstServer struct { 41 *httptest.Server 42 URL string 43 } 44 45 const ( 46 cstPath = "/a/b" 47 cstRawQuery = "x=y" 48 cstRequestURI = cstPath + "?" + cstRawQuery 49 ) 50 51 func newServer(t *testing.T) *cstServer { 52 var s cstServer 53 s.Server = httptest.NewServer(cstHandler{t}) 54 s.Server.URL += cstRequestURI 55 s.URL = makeWsProto(s.Server.URL) 56 return &s 57 } 58 59 func newTLSServer(t *testing.T) *cstServer { 60 var s cstServer 61 s.Server = httptest.NewTLSServer(cstHandler{t}) 62 s.Server.URL += cstRequestURI 63 s.URL = makeWsProto(s.Server.URL) 64 return &s 65 } 66 67 func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { 68 if r.URL.Path != cstPath { 69 t.Logf("path=%v, want %v", r.URL.Path, cstPath) 70 http.Error(w, "bad path", 400) 71 return 72 } 73 if r.URL.RawQuery != cstRawQuery { 74 t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) 75 http.Error(w, "bad path", 400) 76 return 77 } 78 subprotos := Subprotocols(r) 79 if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) { 80 t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols) 81 http.Error(w, "bad protocol", 400) 82 return 83 } 84 ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"self.Session.D=1234"}}) 85 if err != nil { 86 t.Logf("Upgrade: %v", err) 87 return 88 } 89 defer ws.Close() 90 91 if ws.Subprotocol() != "p1" { 92 t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol()) 93 ws.Close() 94 return 95 } 96 op, rd, err := ws.NextReader() 97 if err != nil { 98 t.Logf("NextReader: %v", err) 99 return 100 } 101 wr, err := ws.NextWriter(op) 102 if err != nil { 103 t.Logf("NextWriter: %v", err) 104 return 105 } 106 if _, err = io.Copy(wr, rd); err != nil { 107 t.Logf("NextWriter: %v", err) 108 return 109 } 110 if err := wr.Close(); err != nil { 111 t.Logf("Close: %v", err) 112 return 113 } 114 } 115 116 func makeWsProto(s string) string { 117 return "ws" + strings.TrimPrefix(s, "http") 118 } 119 120 func sendRecv(t *testing.T, ws *Conn) { 121 const message = "Hello World!" 122 if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { 123 t.Fatalf("SetWriteDeadline: %v", err) 124 } 125 if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil { 126 t.Fatalf("WriteMessage: %v", err) 127 } 128 if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil { 129 t.Fatalf("SetReadDeadline: %v", err) 130 } 131 _, p, err := ws.ReadMessage() 132 if err != nil { 133 t.Fatalf("ReadMessage: %v", err) 134 } 135 if string(p) != message { 136 t.Fatalf("message=%s, want %s", p, message) 137 } 138 } 139 140 func TestProxyDial(t *testing.T) { 141 142 s := newServer(t) 143 defer s.Close() 144 145 surl, _ := url.Parse(s.URL) 146 147 cstDialer.Proxy = http.ProxyURL(surl) 148 149 connect := false 150 origHandler := s.Server.Config.Handler 151 152 // Capture the request Host header. 153 s.Server.Config.Handler = http.HandlerFunc( 154 func(w http.ResponseWriter, r *http.Request) { 155 if r.Method == "CONNECT" { 156 connect = true 157 w.WriteHeader(200) 158 return 159 } 160 161 if !connect { 162 t.Log("connect not recieved") 163 http.Error(w, "connect not recieved", 405) 164 return 165 } 166 origHandler.ServeHTTP(w, r) 167 }) 168 169 ws, _, err := cstDialer.Dial(s.URL, nil) 170 if err != nil { 171 t.Fatalf("Dial: %v", err) 172 } 173 defer ws.Close() 174 sendRecv(t, ws) 175 176 cstDialer.Proxy = http.ProxyFromEnvironment 177 } 178 179 func TestProxyAuthorizationDial(t *testing.T) { 180 s := newServer(t) 181 defer s.Close() 182 183 surl, _ := url.Parse(s.URL) 184 surl.User = url.UserPassword("username", "password") 185 cstDialer.Proxy = http.ProxyURL(surl) 186 187 connect := false 188 origHandler := s.Server.Config.Handler 189 190 // Capture the request Host header. 191 s.Server.Config.Handler = http.HandlerFunc( 192 func(w http.ResponseWriter, r *http.Request) { 193 proxyAuth := r.Header.Get("Proxy-Authorization") 194 expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password")) 195 if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth { 196 connect = true 197 w.WriteHeader(200) 198 return 199 } 200 201 if !connect { 202 t.Log("connect with proxy authorization not recieved") 203 http.Error(w, "connect with proxy authorization not recieved", 405) 204 return 205 } 206 origHandler.ServeHTTP(w, r) 207 }) 208 209 ws, _, err := cstDialer.Dial(s.URL, nil) 210 if err != nil { 211 t.Fatalf("Dial: %v", err) 212 } 213 defer ws.Close() 214 sendRecv(t, ws) 215 216 cstDialer.Proxy = http.ProxyFromEnvironment 217 } 218 219 func TestDial(t *testing.T) { 220 s := newServer(t) 221 defer s.Close() 222 223 ws, _, err := cstDialer.Dial(s.URL, nil) 224 if err != nil { 225 t.Fatalf("Dial: %v", err) 226 } 227 defer ws.Close() 228 sendRecv(t, ws) 229 } 230 231 func TestDialTLS(t *testing.T) { 232 s := newTLSServer(t) 233 defer s.Close() 234 235 certs := x509.NewCertPool() 236 for _, c := range s.TLS.Certificates { 237 roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1]) 238 if err != nil { 239 t.Fatalf("error parsing server's root cert: %v", err) 240 } 241 for _, root := range roots { 242 certs.AddCert(root) 243 } 244 } 245 246 u, _ := url.Parse(s.URL) 247 d := cstDialer 248 d.NetDial = func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) } 249 d.TLSClientConfig = &tls.Config{RootCAs: certs} 250 ws, _, err := d.Dial("wss://example.com"+cstRequestURI, nil) 251 if err != nil { 252 t.Fatalf("Dial: %v", err) 253 } 254 defer ws.Close() 255 sendRecv(t, ws) 256 } 257 258 func xTestDialTLSBadCert(t *testing.T) { 259 // This test is deactivated because of noisy logging from the net/http package. 260 s := newTLSServer(t) 261 defer s.Close() 262 263 ws, _, err := cstDialer.Dial(s.URL, nil) 264 if err == nil { 265 ws.Close() 266 t.Fatalf("Dial: nil") 267 } 268 } 269 270 func xTestDialTLSNoVerify(t *testing.T) { 271 s := newTLSServer(t) 272 defer s.Close() 273 274 d := cstDialer 275 d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true} 276 ws, _, err := d.Dial(s.URL, nil) 277 if err != nil { 278 t.Fatalf("Dial: %v", err) 279 } 280 defer ws.Close() 281 sendRecv(t, ws) 282 } 283 284 func TestDialTimeout(t *testing.T) { 285 s := newServer(t) 286 defer s.Close() 287 288 d := cstDialer 289 d.HandshakeTimeout = -1 290 ws, _, err := d.Dial(s.URL, nil) 291 if err == nil { 292 ws.Close() 293 t.Fatalf("Dial: nil") 294 } 295 } 296 297 func TestDialBadScheme(t *testing.T) { 298 s := newServer(t) 299 defer s.Close() 300 301 ws, _, err := cstDialer.Dial(s.Server.URL, nil) 302 if err == nil { 303 ws.Close() 304 t.Fatalf("Dial: nil") 305 } 306 } 307 308 func TestDialBadOrigin(t *testing.T) { 309 s := newServer(t) 310 defer s.Close() 311 312 ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) 313 if err == nil { 314 ws.Close() 315 t.Fatalf("Dial: nil") 316 } 317 if resp == nil { 318 t.Fatalf("resp=nil, err=%v", err) 319 } 320 if resp.StatusCode != http.StatusForbidden { 321 t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden) 322 } 323 } 324 325 func TestDialBadHeader(t *testing.T) { 326 s := newServer(t) 327 defer s.Close() 328 329 for _, k := range []string{"Upgrade", 330 "Connection", 331 "Sec-Websocket-Key", 332 "Sec-Websocket-Version", 333 "Sec-Websocket-Protocol"} { 334 h := http.Header{} 335 h.Set(k, "bad") 336 ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) 337 if err == nil { 338 ws.Close() 339 t.Errorf("Dial with header %s returned nil", k) 340 } 341 } 342 } 343 344 func TestBadMethod(t *testing.T) { 345 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 346 ws, err := cstUpgrader.Upgrade(w, r, nil) 347 if err == nil { 348 t.Errorf("handshake succeeded, expect fail") 349 ws.Close() 350 } 351 })) 352 defer s.Close() 353 354 resp, err := http.PostForm(s.URL, url.Values{}) 355 if err != nil { 356 t.Fatalf("PostForm returned error %v", err) 357 } 358 resp.Body.Close() 359 if resp.StatusCode != http.StatusMethodNotAllowed { 360 t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed) 361 } 362 } 363 364 func TestHandshake(t *testing.T) { 365 s := newServer(t) 366 defer s.Close() 367 368 ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}}) 369 if err != nil { 370 t.Fatalf("Dial: %v", err) 371 } 372 defer ws.Close() 373 374 var self.Session.D string 375 for _, c := range resp.Cookies() { 376 if c.Name == "self.Session.D" { 377 self.Session.D = c.Value 378 } 379 } 380 if self.Session.D != "1234" { 381 t.Error("Set-Cookie not received from the server.") 382 } 383 384 if ws.Subprotocol() != "p1" { 385 t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol()) 386 } 387 sendRecv(t, ws) 388 } 389 390 func TestRespOnBadHandshake(t *testing.T) { 391 const expectedStatus = http.StatusGone 392 const expectedBody = "This is the response body." 393 394 s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { 395 w.WriteHeader(expectedStatus) 396 io.WriteString(w, expectedBody) 397 })) 398 defer s.Close() 399 400 ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) 401 if err == nil { 402 ws.Close() 403 t.Fatalf("Dial: nil") 404 } 405 406 if resp == nil { 407 t.Fatalf("resp=nil, err=%v", err) 408 } 409 410 if resp.StatusCode != expectedStatus { 411 t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) 412 } 413 414 p, err := ioutil.ReadAll(resp.Body) 415 if err != nil { 416 t.Fatalf("ReadFull(resp.Body) returned error %v", err) 417 } 418 419 if string(p) != expectedBody { 420 t.Errorf("resp.Body=%s, want %s", p, expectedBody) 421 } 422 } 423 424 // TestHostHeader confirms that the host header provided in the call to Dial is 425 // sent to the server. 426 func TestHostHeader(t *testing.T) { 427 s := newServer(t) 428 defer s.Close() 429 430 specifiedHost := make(chan string, 1) 431 origHandler := s.Server.Config.Handler 432 433 // Capture the request Host header. 434 s.Server.Config.Handler = http.HandlerFunc( 435 func(w http.ResponseWriter, r *http.Request) { 436 specifiedHost <- r.Host 437 origHandler.ServeHTTP(w, r) 438 }) 439 440 ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}}) 441 if err != nil { 442 t.Fatalf("Dial: %v", err) 443 } 444 defer ws.Close() 445 446 if gotHost := <-specifiedHost; gotHost != "testhost" { 447 t.Fatalf("gotHost = %q, want \"testhost\"", gotHost) 448 } 449 450 sendRecv(t, ws) 451 }