github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/gorilla/websocket/client_server_test.go (about)

     1  // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"crypto/tls"
     9  	"crypto/x509"
    10  	"encoding/base64"
    11  	"io"
    12  	"io/ioutil"
    13  	"net"
    14  	"net/http"
    15  	"net/http/httptest"
    16  	"net/url"
    17  	"reflect"
    18  	"strings"
    19  	"testing"
    20  	"time"
    21  )
    22  
    23  var cstUpgrader = Upgrader{
    24  	Subprotocols:    []string{"p0", "p1"},
    25  	ReadBufferSize:  1024,
    26  	WriteBufferSize: 1024,
    27  	Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
    28  		http.Error(w, reason.Error(), status)
    29  	},
    30  }
    31  
    32  var cstDialer = Dialer{
    33  	Subprotocols:    []string{"p1", "p2"},
    34  	ReadBufferSize:  1024,
    35  	WriteBufferSize: 1024,
    36  }
    37  
    38  type cstHandler struct{ *testing.T }
    39  
    40  type cstServer struct {
    41  	*httptest.Server
    42  	URL string
    43  }
    44  
    45  const (
    46  	cstPath       = "/a/b"
    47  	cstRawQuery   = "x=y"
    48  	cstRequestURI = cstPath + "?" + cstRawQuery
    49  )
    50  
    51  func newServer(t *testing.T) *cstServer {
    52  	var s cstServer
    53  	s.Server = httptest.NewServer(cstHandler{t})
    54  	s.Server.URL += cstRequestURI
    55  	s.URL = makeWsProto(s.Server.URL)
    56  	return &s
    57  }
    58  
    59  func newTLSServer(t *testing.T) *cstServer {
    60  	var s cstServer
    61  	s.Server = httptest.NewTLSServer(cstHandler{t})
    62  	s.Server.URL += cstRequestURI
    63  	s.URL = makeWsProto(s.Server.URL)
    64  	return &s
    65  }
    66  
    67  func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    68  	if r.URL.Path != cstPath {
    69  		t.Logf("path=%v, want %v", r.URL.Path, cstPath)
    70  		http.Error(w, "bad path", 400)
    71  		return
    72  	}
    73  	if r.URL.RawQuery != cstRawQuery {
    74  		t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
    75  		http.Error(w, "bad path", 400)
    76  		return
    77  	}
    78  	subprotos := Subprotocols(r)
    79  	if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
    80  		t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
    81  		http.Error(w, "bad protocol", 400)
    82  		return
    83  	}
    84  	ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"self.Session.D=1234"}})
    85  	if err != nil {
    86  		t.Logf("Upgrade: %v", err)
    87  		return
    88  	}
    89  	defer ws.Close()
    90  
    91  	if ws.Subprotocol() != "p1" {
    92  		t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
    93  		ws.Close()
    94  		return
    95  	}
    96  	op, rd, err := ws.NextReader()
    97  	if err != nil {
    98  		t.Logf("NextReader: %v", err)
    99  		return
   100  	}
   101  	wr, err := ws.NextWriter(op)
   102  	if err != nil {
   103  		t.Logf("NextWriter: %v", err)
   104  		return
   105  	}
   106  	if _, err = io.Copy(wr, rd); err != nil {
   107  		t.Logf("NextWriter: %v", err)
   108  		return
   109  	}
   110  	if err := wr.Close(); err != nil {
   111  		t.Logf("Close: %v", err)
   112  		return
   113  	}
   114  }
   115  
   116  func makeWsProto(s string) string {
   117  	return "ws" + strings.TrimPrefix(s, "http")
   118  }
   119  
   120  func sendRecv(t *testing.T, ws *Conn) {
   121  	const message = "Hello World!"
   122  	if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
   123  		t.Fatalf("SetWriteDeadline: %v", err)
   124  	}
   125  	if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
   126  		t.Fatalf("WriteMessage: %v", err)
   127  	}
   128  	if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
   129  		t.Fatalf("SetReadDeadline: %v", err)
   130  	}
   131  	_, p, err := ws.ReadMessage()
   132  	if err != nil {
   133  		t.Fatalf("ReadMessage: %v", err)
   134  	}
   135  	if string(p) != message {
   136  		t.Fatalf("message=%s, want %s", p, message)
   137  	}
   138  }
   139  
   140  func TestProxyDial(t *testing.T) {
   141  
   142  	s := newServer(t)
   143  	defer s.Close()
   144  
   145  	surl, _ := url.Parse(s.URL)
   146  
   147  	cstDialer.Proxy = http.ProxyURL(surl)
   148  
   149  	connect := false
   150  	origHandler := s.Server.Config.Handler
   151  
   152  	// Capture the request Host header.
   153  	s.Server.Config.Handler = http.HandlerFunc(
   154  		func(w http.ResponseWriter, r *http.Request) {
   155  			if r.Method == "CONNECT" {
   156  				connect = true
   157  				w.WriteHeader(200)
   158  				return
   159  			}
   160  
   161  			if !connect {
   162  				t.Log("connect not recieved")
   163  				http.Error(w, "connect not recieved", 405)
   164  				return
   165  			}
   166  			origHandler.ServeHTTP(w, r)
   167  		})
   168  
   169  	ws, _, err := cstDialer.Dial(s.URL, nil)
   170  	if err != nil {
   171  		t.Fatalf("Dial: %v", err)
   172  	}
   173  	defer ws.Close()
   174  	sendRecv(t, ws)
   175  
   176  	cstDialer.Proxy = http.ProxyFromEnvironment
   177  }
   178  
   179  func TestProxyAuthorizationDial(t *testing.T) {
   180  	s := newServer(t)
   181  	defer s.Close()
   182  
   183  	surl, _ := url.Parse(s.URL)
   184  	surl.User = url.UserPassword("username", "password")
   185  	cstDialer.Proxy = http.ProxyURL(surl)
   186  
   187  	connect := false
   188  	origHandler := s.Server.Config.Handler
   189  
   190  	// Capture the request Host header.
   191  	s.Server.Config.Handler = http.HandlerFunc(
   192  		func(w http.ResponseWriter, r *http.Request) {
   193  			proxyAuth := r.Header.Get("Proxy-Authorization")
   194  			expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
   195  			if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
   196  				connect = true
   197  				w.WriteHeader(200)
   198  				return
   199  			}
   200  
   201  			if !connect {
   202  				t.Log("connect with proxy authorization not recieved")
   203  				http.Error(w, "connect with proxy authorization not recieved", 405)
   204  				return
   205  			}
   206  			origHandler.ServeHTTP(w, r)
   207  		})
   208  
   209  	ws, _, err := cstDialer.Dial(s.URL, nil)
   210  	if err != nil {
   211  		t.Fatalf("Dial: %v", err)
   212  	}
   213  	defer ws.Close()
   214  	sendRecv(t, ws)
   215  
   216  	cstDialer.Proxy = http.ProxyFromEnvironment
   217  }
   218  
   219  func TestDial(t *testing.T) {
   220  	s := newServer(t)
   221  	defer s.Close()
   222  
   223  	ws, _, err := cstDialer.Dial(s.URL, nil)
   224  	if err != nil {
   225  		t.Fatalf("Dial: %v", err)
   226  	}
   227  	defer ws.Close()
   228  	sendRecv(t, ws)
   229  }
   230  
   231  func TestDialTLS(t *testing.T) {
   232  	s := newTLSServer(t)
   233  	defer s.Close()
   234  
   235  	certs := x509.NewCertPool()
   236  	for _, c := range s.TLS.Certificates {
   237  		roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
   238  		if err != nil {
   239  			t.Fatalf("error parsing server's root cert: %v", err)
   240  		}
   241  		for _, root := range roots {
   242  			certs.AddCert(root)
   243  		}
   244  	}
   245  
   246  	u, _ := url.Parse(s.URL)
   247  	d := cstDialer
   248  	d.NetDial = func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) }
   249  	d.TLSClientConfig = &tls.Config{RootCAs: certs}
   250  	ws, _, err := d.Dial("wss://example.com"+cstRequestURI, nil)
   251  	if err != nil {
   252  		t.Fatalf("Dial: %v", err)
   253  	}
   254  	defer ws.Close()
   255  	sendRecv(t, ws)
   256  }
   257  
   258  func xTestDialTLSBadCert(t *testing.T) {
   259  	// This test is deactivated because of noisy logging from the net/http package.
   260  	s := newTLSServer(t)
   261  	defer s.Close()
   262  
   263  	ws, _, err := cstDialer.Dial(s.URL, nil)
   264  	if err == nil {
   265  		ws.Close()
   266  		t.Fatalf("Dial: nil")
   267  	}
   268  }
   269  
   270  func xTestDialTLSNoVerify(t *testing.T) {
   271  	s := newTLSServer(t)
   272  	defer s.Close()
   273  
   274  	d := cstDialer
   275  	d.TLSClientConfig = &tls.Config{InsecureSkipVerify: true}
   276  	ws, _, err := d.Dial(s.URL, nil)
   277  	if err != nil {
   278  		t.Fatalf("Dial: %v", err)
   279  	}
   280  	defer ws.Close()
   281  	sendRecv(t, ws)
   282  }
   283  
   284  func TestDialTimeout(t *testing.T) {
   285  	s := newServer(t)
   286  	defer s.Close()
   287  
   288  	d := cstDialer
   289  	d.HandshakeTimeout = -1
   290  	ws, _, err := d.Dial(s.URL, nil)
   291  	if err == nil {
   292  		ws.Close()
   293  		t.Fatalf("Dial: nil")
   294  	}
   295  }
   296  
   297  func TestDialBadScheme(t *testing.T) {
   298  	s := newServer(t)
   299  	defer s.Close()
   300  
   301  	ws, _, err := cstDialer.Dial(s.Server.URL, nil)
   302  	if err == nil {
   303  		ws.Close()
   304  		t.Fatalf("Dial: nil")
   305  	}
   306  }
   307  
   308  func TestDialBadOrigin(t *testing.T) {
   309  	s := newServer(t)
   310  	defer s.Close()
   311  
   312  	ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
   313  	if err == nil {
   314  		ws.Close()
   315  		t.Fatalf("Dial: nil")
   316  	}
   317  	if resp == nil {
   318  		t.Fatalf("resp=nil, err=%v", err)
   319  	}
   320  	if resp.StatusCode != http.StatusForbidden {
   321  		t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden)
   322  	}
   323  }
   324  
   325  func TestDialBadHeader(t *testing.T) {
   326  	s := newServer(t)
   327  	defer s.Close()
   328  
   329  	for _, k := range []string{"Upgrade",
   330  		"Connection",
   331  		"Sec-Websocket-Key",
   332  		"Sec-Websocket-Version",
   333  		"Sec-Websocket-Protocol"} {
   334  		h := http.Header{}
   335  		h.Set(k, "bad")
   336  		ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
   337  		if err == nil {
   338  			ws.Close()
   339  			t.Errorf("Dial with header %s returned nil", k)
   340  		}
   341  	}
   342  }
   343  
   344  func TestBadMethod(t *testing.T) {
   345  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   346  		ws, err := cstUpgrader.Upgrade(w, r, nil)
   347  		if err == nil {
   348  			t.Errorf("handshake succeeded, expect fail")
   349  			ws.Close()
   350  		}
   351  	}))
   352  	defer s.Close()
   353  
   354  	resp, err := http.PostForm(s.URL, url.Values{})
   355  	if err != nil {
   356  		t.Fatalf("PostForm returned error %v", err)
   357  	}
   358  	resp.Body.Close()
   359  	if resp.StatusCode != http.StatusMethodNotAllowed {
   360  		t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
   361  	}
   362  }
   363  
   364  func TestHandshake(t *testing.T) {
   365  	s := newServer(t)
   366  	defer s.Close()
   367  
   368  	ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
   369  	if err != nil {
   370  		t.Fatalf("Dial: %v", err)
   371  	}
   372  	defer ws.Close()
   373  
   374  	var self.Session.D string
   375  	for _, c := range resp.Cookies() {
   376  		if c.Name == "self.Session.D" {
   377  			self.Session.D = c.Value
   378  		}
   379  	}
   380  	if self.Session.D != "1234" {
   381  		t.Error("Set-Cookie not received from the server.")
   382  	}
   383  
   384  	if ws.Subprotocol() != "p1" {
   385  		t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
   386  	}
   387  	sendRecv(t, ws)
   388  }
   389  
   390  func TestRespOnBadHandshake(t *testing.T) {
   391  	const expectedStatus = http.StatusGone
   392  	const expectedBody = "This is the response body."
   393  
   394  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   395  		w.WriteHeader(expectedStatus)
   396  		io.WriteString(w, expectedBody)
   397  	}))
   398  	defer s.Close()
   399  
   400  	ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
   401  	if err == nil {
   402  		ws.Close()
   403  		t.Fatalf("Dial: nil")
   404  	}
   405  
   406  	if resp == nil {
   407  		t.Fatalf("resp=nil, err=%v", err)
   408  	}
   409  
   410  	if resp.StatusCode != expectedStatus {
   411  		t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
   412  	}
   413  
   414  	p, err := ioutil.ReadAll(resp.Body)
   415  	if err != nil {
   416  		t.Fatalf("ReadFull(resp.Body) returned error %v", err)
   417  	}
   418  
   419  	if string(p) != expectedBody {
   420  		t.Errorf("resp.Body=%s, want %s", p, expectedBody)
   421  	}
   422  }
   423  
   424  // TestHostHeader confirms that the host header provided in the call to Dial is
   425  // sent to the server.
   426  func TestHostHeader(t *testing.T) {
   427  	s := newServer(t)
   428  	defer s.Close()
   429  
   430  	specifiedHost := make(chan string, 1)
   431  	origHandler := s.Server.Config.Handler
   432  
   433  	// Capture the request Host header.
   434  	s.Server.Config.Handler = http.HandlerFunc(
   435  		func(w http.ResponseWriter, r *http.Request) {
   436  			specifiedHost <- r.Host
   437  			origHandler.ServeHTTP(w, r)
   438  		})
   439  
   440  	ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}})
   441  	if err != nil {
   442  		t.Fatalf("Dial: %v", err)
   443  	}
   444  	defer ws.Close()
   445  
   446  	if gotHost := <-specifiedHost; gotHost != "testhost" {
   447  		t.Fatalf("gotHost = %q, want \"testhost\"", gotHost)
   448  	}
   449  
   450  	sendRecv(t, ws)
   451  }