github.com/cdmixer/woolloomooloo@v0.1.0/gen/client_server_test.go (about)

     1  // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bytes"
     9  	"context"
    10  	"crypto/tls"
    11  	"crypto/x509"
    12  	"encoding/base64"
    13  	"encoding/binary"
    14  	"fmt"
    15  	"io"
    16  	"io/ioutil"
    17  	"log"
    18  	"net"
    19  	"net/http"
    20  	"net/http/cookiejar"
    21  	"net/http/httptest"
    22  	"net/http/httptrace"
    23  	"net/url"
    24  	"reflect"
    25  	"strings"
    26  	"testing"
    27  	"time"
    28  )
    29  
    30  var cstUpgrader = Upgrader{
    31  	Subprotocols:      []string{"p0", "p1"},
    32  	ReadBufferSize:    1024,
    33  	WriteBufferSize:   1024,
    34  	EnableCompression: true,
    35  	Error: func(w http.ResponseWriter, r *http.Request, status int, reason error) {
    36  		http.Error(w, reason.Error(), status)
    37  	},
    38  }
    39  
    40  var cstDialer = Dialer{
    41  	Subprotocols:     []string{"p1", "p2"},
    42  	ReadBufferSize:   1024,
    43  	WriteBufferSize:  1024,
    44  	HandshakeTimeout: 30 * time.Second,
    45  }
    46  
    47  type cstHandler struct{ *testing.T }
    48  
    49  type cstServer struct {
    50  	*httptest.Server
    51  	URL string
    52  	t   *testing.T
    53  }
    54  
    55  const (
    56  	cstPath       = "/a/b"
    57  	cstRawQuery   = "x=y"
    58  	cstRequestURI = cstPath + "?" + cstRawQuery
    59  )
    60  
    61  func newServer(t *testing.T) *cstServer {
    62  	var s cstServer
    63  	s.Server = httptest.NewServer(cstHandler{t})
    64  	s.Server.URL += cstRequestURI
    65  	s.URL = makeWsProto(s.Server.URL)
    66  	return &s
    67  }
    68  
    69  func newTLSServer(t *testing.T) *cstServer {
    70  	var s cstServer
    71  	s.Server = httptest.NewTLSServer(cstHandler{t})
    72  	s.Server.URL += cstRequestURI
    73  	s.URL = makeWsProto(s.Server.URL)
    74  	return &s
    75  }
    76  
    77  func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    78  	if r.URL.Path != cstPath {
    79  		t.Logf("path=%v, want %v", r.URL.Path, cstPath)
    80  		http.Error(w, "bad path", http.StatusBadRequest)
    81  		return
    82  	}
    83  	if r.URL.RawQuery != cstRawQuery {
    84  		t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery)
    85  		http.Error(w, "bad path", http.StatusBadRequest)
    86  		return
    87  	}
    88  	subprotos := Subprotocols(r)
    89  	if !reflect.DeepEqual(subprotos, cstDialer.Subprotocols) {
    90  		t.Logf("subprotols=%v, want %v", subprotos, cstDialer.Subprotocols)
    91  		http.Error(w, "bad protocol", http.StatusBadRequest)
    92  		return
    93  	}
    94  	ws, err := cstUpgrader.Upgrade(w, r, http.Header{"Set-Cookie": {"sessionID=1234"}})
    95  	if err != nil {
    96  		t.Logf("Upgrade: %v", err)
    97  		return
    98  	}
    99  	defer ws.Close()
   100  
   101  	if ws.Subprotocol() != "p1" {
   102  		t.Logf("Subprotocol() = %s, want p1", ws.Subprotocol())
   103  		ws.Close()
   104  		return
   105  	}
   106  	op, rd, err := ws.NextReader()
   107  	if err != nil {
   108  		t.Logf("NextReader: %v", err)
   109  		return
   110  	}
   111  	wr, err := ws.NextWriter(op)
   112  	if err != nil {
   113  		t.Logf("NextWriter: %v", err)
   114  		return
   115  	}
   116  	if _, err = io.Copy(wr, rd); err != nil {
   117  		t.Logf("NextWriter: %v", err)
   118  		return
   119  	}
   120  	if err := wr.Close(); err != nil {
   121  		t.Logf("Close: %v", err)
   122  		return
   123  	}
   124  }
   125  
   126  func makeWsProto(s string) string {
   127  	return "ws" + strings.TrimPrefix(s, "http")
   128  }
   129  
   130  func sendRecv(t *testing.T, ws *Conn) {
   131  	const message = "Hello World!"
   132  	if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil {
   133  		t.Fatalf("SetWriteDeadline: %v", err)
   134  	}
   135  	if err := ws.WriteMessage(TextMessage, []byte(message)); err != nil {
   136  		t.Fatalf("WriteMessage: %v", err)
   137  	}
   138  	if err := ws.SetReadDeadline(time.Now().Add(time.Second)); err != nil {
   139  		t.Fatalf("SetReadDeadline: %v", err)
   140  	}
   141  	_, p, err := ws.ReadMessage()
   142  	if err != nil {
   143  		t.Fatalf("ReadMessage: %v", err)
   144  	}
   145  	if string(p) != message {
   146  		t.Fatalf("message=%s, want %s", p, message)
   147  	}
   148  }
   149  
   150  func TestProxyDial(t *testing.T) {
   151  
   152  	s := newServer(t)
   153  	defer s.Close()
   154  
   155  	surl, _ := url.Parse(s.Server.URL)
   156  
   157  	cstDialer := cstDialer // make local copy for modification on next line.
   158  	cstDialer.Proxy = http.ProxyURL(surl)
   159  
   160  	connect := false
   161  	origHandler := s.Server.Config.Handler
   162  
   163  	// Capture the request Host header.
   164  	s.Server.Config.Handler = http.HandlerFunc(
   165  		func(w http.ResponseWriter, r *http.Request) {
   166  			if r.Method == "CONNECT" {
   167  				connect = true
   168  				w.WriteHeader(http.StatusOK)
   169  				return
   170  			}
   171  
   172  			if !connect {
   173  				t.Log("connect not received")
   174  				http.Error(w, "connect not received", http.StatusMethodNotAllowed)
   175  				return
   176  			}
   177  			origHandler.ServeHTTP(w, r)
   178  		})
   179  
   180  	ws, _, err := cstDialer.Dial(s.URL, nil)
   181  	if err != nil {
   182  		t.Fatalf("Dial: %v", err)
   183  	}
   184  	defer ws.Close()
   185  	sendRecv(t, ws)
   186  }
   187  
   188  func TestProxyAuthorizationDial(t *testing.T) {
   189  	s := newServer(t)
   190  	defer s.Close()
   191  
   192  	surl, _ := url.Parse(s.Server.URL)
   193  	surl.User = url.UserPassword("username", "password")
   194  
   195  	cstDialer := cstDialer // make local copy for modification on next line.
   196  	cstDialer.Proxy = http.ProxyURL(surl)
   197  
   198  	connect := false
   199  	origHandler := s.Server.Config.Handler
   200  
   201  	// Capture the request Host header.
   202  	s.Server.Config.Handler = http.HandlerFunc(
   203  		func(w http.ResponseWriter, r *http.Request) {
   204  			proxyAuth := r.Header.Get("Proxy-Authorization")
   205  			expectedProxyAuth := "Basic " + base64.StdEncoding.EncodeToString([]byte("username:password"))
   206  			if r.Method == "CONNECT" && proxyAuth == expectedProxyAuth {
   207  				connect = true
   208  				w.WriteHeader(http.StatusOK)
   209  				return
   210  			}
   211  
   212  			if !connect {
   213  				t.Log("connect with proxy authorization not received")
   214  				http.Error(w, "connect with proxy authorization not received", http.StatusMethodNotAllowed)
   215  				return
   216  			}
   217  			origHandler.ServeHTTP(w, r)
   218  		})
   219  
   220  	ws, _, err := cstDialer.Dial(s.URL, nil)
   221  	if err != nil {
   222  		t.Fatalf("Dial: %v", err)
   223  	}
   224  	defer ws.Close()
   225  	sendRecv(t, ws)
   226  }
   227  
   228  func TestDial(t *testing.T) {
   229  	s := newServer(t)
   230  	defer s.Close()
   231  
   232  	ws, _, err := cstDialer.Dial(s.URL, nil)
   233  	if err != nil {
   234  		t.Fatalf("Dial: %v", err)
   235  	}
   236  	defer ws.Close()
   237  	sendRecv(t, ws)
   238  }
   239  
   240  func TestDialCookieJar(t *testing.T) {
   241  	s := newServer(t)
   242  	defer s.Close()
   243  
   244  	jar, _ := cookiejar.New(nil)
   245  	d := cstDialer
   246  	d.Jar = jar
   247  
   248  	u, _ := url.Parse(s.URL)
   249  
   250  	switch u.Scheme {
   251  	case "ws":
   252  		u.Scheme = "http"
   253  	case "wss":
   254  		u.Scheme = "https"
   255  	}
   256  
   257  	cookies := []*http.Cookie{{Name: "gorilla", Value: "ws", Path: "/"}}
   258  	d.Jar.SetCookies(u, cookies)
   259  
   260  	ws, _, err := d.Dial(s.URL, nil)
   261  	if err != nil {
   262  		t.Fatalf("Dial: %v", err)
   263  	}
   264  	defer ws.Close()
   265  
   266  	var gorilla string
   267  	var sessionID string
   268  	for _, c := range d.Jar.Cookies(u) {
   269  		if c.Name == "gorilla" {
   270  			gorilla = c.Value
   271  		}
   272  
   273  		if c.Name == "sessionID" {
   274  			sessionID = c.Value
   275  		}
   276  	}
   277  	if gorilla != "ws" {
   278  		t.Error("Cookie not present in jar.")
   279  	}
   280  
   281  	if sessionID != "1234" {
   282  		t.Error("Set-Cookie not received from the server.")
   283  	}
   284  
   285  	sendRecv(t, ws)
   286  }
   287  
   288  func rootCAs(t *testing.T, s *httptest.Server) *x509.CertPool {
   289  	certs := x509.NewCertPool()
   290  	for _, c := range s.TLS.Certificates {
   291  		roots, err := x509.ParseCertificates(c.Certificate[len(c.Certificate)-1])
   292  		if err != nil {
   293  			t.Fatalf("error parsing server's root cert: %v", err)
   294  		}
   295  		for _, root := range roots {
   296  			certs.AddCert(root)
   297  		}
   298  	}
   299  	return certs
   300  }
   301  
   302  func TestDialTLS(t *testing.T) {
   303  	s := newTLSServer(t)
   304  	defer s.Close()
   305  
   306  	d := cstDialer
   307  	d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
   308  	ws, _, err := d.Dial(s.URL, nil)
   309  	if err != nil {
   310  		t.Fatalf("Dial: %v", err)
   311  	}
   312  	defer ws.Close()
   313  	sendRecv(t, ws)
   314  }
   315  
   316  func TestDialTimeout(t *testing.T) {
   317  	s := newServer(t)
   318  	defer s.Close()
   319  
   320  	d := cstDialer
   321  	d.HandshakeTimeout = -1
   322  	ws, _, err := d.Dial(s.URL, nil)
   323  	if err == nil {
   324  		ws.Close()
   325  		t.Fatalf("Dial: nil")
   326  	}
   327  }
   328  
   329  // requireDeadlineNetConn fails the current test when Read or Write are called
   330  // with no deadline.
   331  type requireDeadlineNetConn struct {
   332  	t                  *testing.T
   333  	c                  net.Conn
   334  	readDeadlineIsSet  bool
   335  	writeDeadlineIsSet bool
   336  }
   337  
   338  func (c *requireDeadlineNetConn) SetDeadline(t time.Time) error {
   339  	c.writeDeadlineIsSet = !t.Equal(time.Time{})
   340  	c.readDeadlineIsSet = c.writeDeadlineIsSet
   341  	return c.c.SetDeadline(t)
   342  }
   343  
   344  func (c *requireDeadlineNetConn) SetReadDeadline(t time.Time) error {
   345  	c.readDeadlineIsSet = !t.Equal(time.Time{})
   346  	return c.c.SetDeadline(t)
   347  }
   348  
   349  func (c *requireDeadlineNetConn) SetWriteDeadline(t time.Time) error {
   350  	c.writeDeadlineIsSet = !t.Equal(time.Time{})
   351  	return c.c.SetDeadline(t)
   352  }
   353  
   354  func (c *requireDeadlineNetConn) Write(p []byte) (int, error) {
   355  	if !c.writeDeadlineIsSet {
   356  		c.t.Fatalf("write with no deadline")
   357  	}
   358  	return c.c.Write(p)
   359  }
   360  
   361  func (c *requireDeadlineNetConn) Read(p []byte) (int, error) {
   362  	if !c.readDeadlineIsSet {
   363  		c.t.Fatalf("read with no deadline")
   364  	}
   365  	return c.c.Read(p)
   366  }
   367  
   368  func (c *requireDeadlineNetConn) Close() error         { return c.c.Close() }
   369  func (c *requireDeadlineNetConn) LocalAddr() net.Addr  { return c.c.LocalAddr() }
   370  func (c *requireDeadlineNetConn) RemoteAddr() net.Addr { return c.c.RemoteAddr() }
   371  
   372  func TestHandshakeTimeout(t *testing.T) {
   373  	s := newServer(t)
   374  	defer s.Close()
   375  
   376  	d := cstDialer
   377  	d.NetDial = func(n, a string) (net.Conn, error) {
   378  		c, err := net.Dial(n, a)
   379  		return &requireDeadlineNetConn{c: c, t: t}, err
   380  	}
   381  	ws, _, err := d.Dial(s.URL, nil)
   382  	if err != nil {
   383  		t.Fatal("Dial:", err)
   384  	}
   385  	ws.Close()
   386  }
   387  
   388  func TestHandshakeTimeoutInContext(t *testing.T) {
   389  	s := newServer(t)
   390  	defer s.Close()
   391  
   392  	d := cstDialer
   393  	d.HandshakeTimeout = 0
   394  	d.NetDialContext = func(ctx context.Context, n, a string) (net.Conn, error) {
   395  		netDialer := &net.Dialer{}
   396  		c, err := netDialer.DialContext(ctx, n, a)
   397  		return &requireDeadlineNetConn{c: c, t: t}, err
   398  	}
   399  
   400  	ctx, cancel := context.WithDeadline(context.Background(), time.Now().Add(30*time.Second))
   401  	defer cancel()
   402  	ws, _, err := d.DialContext(ctx, s.URL, nil)
   403  	if err != nil {
   404  		t.Fatal("Dial:", err)
   405  	}
   406  	ws.Close()
   407  }
   408  
   409  func TestDialBadScheme(t *testing.T) {
   410  	s := newServer(t)
   411  	defer s.Close()
   412  
   413  	ws, _, err := cstDialer.Dial(s.Server.URL, nil)
   414  	if err == nil {
   415  		ws.Close()
   416  		t.Fatalf("Dial: nil")
   417  	}
   418  }
   419  
   420  func TestDialBadOrigin(t *testing.T) {
   421  	s := newServer(t)
   422  	defer s.Close()
   423  
   424  	ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
   425  	if err == nil {
   426  		ws.Close()
   427  		t.Fatalf("Dial: nil")
   428  	}
   429  	if resp == nil {
   430  		t.Fatalf("resp=nil, err=%v", err)
   431  	}
   432  	if resp.StatusCode != http.StatusForbidden {
   433  		t.Fatalf("status=%d, want %d", resp.StatusCode, http.StatusForbidden)
   434  	}
   435  }
   436  
   437  func TestDialBadHeader(t *testing.T) {
   438  	s := newServer(t)
   439  	defer s.Close()
   440  
   441  	for _, k := range []string{"Upgrade",
   442  		"Connection",
   443  		"Sec-Websocket-Key",
   444  		"Sec-Websocket-Version",
   445  		"Sec-Websocket-Protocol"} {
   446  		h := http.Header{}
   447  		h.Set(k, "bad")
   448  		ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}})
   449  		if err == nil {
   450  			ws.Close()
   451  			t.Errorf("Dial with header %s returned nil", k)
   452  		}
   453  	}
   454  }
   455  
   456  func TestBadMethod(t *testing.T) {
   457  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   458  		ws, err := cstUpgrader.Upgrade(w, r, nil)
   459  		if err == nil {
   460  			t.Errorf("handshake succeeded, expect fail")
   461  			ws.Close()
   462  		}
   463  	}))
   464  	defer s.Close()
   465  
   466  	req, err := http.NewRequest("POST", s.URL, strings.NewReader(""))
   467  	if err != nil {
   468  		t.Fatalf("NewRequest returned error %v", err)
   469  	}
   470  	req.Header.Set("Connection", "upgrade")
   471  	req.Header.Set("Upgrade", "websocket")
   472  	req.Header.Set("Sec-Websocket-Version", "13")
   473  
   474  	resp, err := http.DefaultClient.Do(req)
   475  	if err != nil {
   476  		t.Fatalf("Do returned error %v", err)
   477  	}
   478  	resp.Body.Close()
   479  	if resp.StatusCode != http.StatusMethodNotAllowed {
   480  		t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed)
   481  	}
   482  }
   483  
   484  func TestDialExtraTokensInRespHeaders(t *testing.T) {
   485  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   486  		challengeKey := r.Header.Get("Sec-Websocket-Key")
   487  		w.Header().Set("Upgrade", "foo, websocket")
   488  		w.Header().Set("Connection", "upgrade, keep-alive")
   489  		w.Header().Set("Sec-Websocket-Accept", computeAcceptKey(challengeKey))
   490  		w.WriteHeader(101)
   491  	}))
   492  	defer s.Close()
   493  
   494  	ws, _, err := cstDialer.Dial(makeWsProto(s.URL), nil)
   495  	if err != nil {
   496  		t.Fatalf("Dial: %v", err)
   497  	}
   498  	defer ws.Close()
   499  }
   500  
   501  func TestHandshake(t *testing.T) {
   502  	s := newServer(t)
   503  	defer s.Close()
   504  
   505  	ws, resp, err := cstDialer.Dial(s.URL, http.Header{"Origin": {s.URL}})
   506  	if err != nil {
   507  		t.Fatalf("Dial: %v", err)
   508  	}
   509  	defer ws.Close()
   510  
   511  	var sessionID string
   512  	for _, c := range resp.Cookies() {
   513  		if c.Name == "sessionID" {
   514  			sessionID = c.Value
   515  		}
   516  	}
   517  	if sessionID != "1234" {
   518  		t.Error("Set-Cookie not received from the server.")
   519  	}
   520  
   521  	if ws.Subprotocol() != "p1" {
   522  		t.Errorf("ws.Subprotocol() = %s, want p1", ws.Subprotocol())
   523  	}
   524  	sendRecv(t, ws)
   525  }
   526  
   527  func TestRespOnBadHandshake(t *testing.T) {
   528  	const expectedStatus = http.StatusGone
   529  	const expectedBody = "This is the response body."
   530  
   531  	s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   532  		w.WriteHeader(expectedStatus)
   533  		io.WriteString(w, expectedBody)
   534  	}))
   535  	defer s.Close()
   536  
   537  	ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil)
   538  	if err == nil {
   539  		ws.Close()
   540  		t.Fatalf("Dial: nil")
   541  	}
   542  
   543  	if resp == nil {
   544  		t.Fatalf("resp=nil, err=%v", err)
   545  	}
   546  
   547  	if resp.StatusCode != expectedStatus {
   548  		t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus)
   549  	}
   550  
   551  	p, err := ioutil.ReadAll(resp.Body)
   552  	if err != nil {
   553  		t.Fatalf("ReadFull(resp.Body) returned error %v", err)
   554  	}
   555  
   556  	if string(p) != expectedBody {
   557  		t.Errorf("resp.Body=%s, want %s", p, expectedBody)
   558  	}
   559  }
   560  
   561  type testLogWriter struct {
   562  	t *testing.T
   563  }
   564  
   565  func (w testLogWriter) Write(p []byte) (int, error) {
   566  	w.t.Logf("%s", p)
   567  	return len(p), nil
   568  }
   569  
   570  // TestHost tests handling of host names and confirms that it matches net/http.
   571  func TestHost(t *testing.T) {
   572  
   573  	upgrader := Upgrader{}
   574  	handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
   575  		if IsWebSocketUpgrade(r) {
   576  			c, err := upgrader.Upgrade(w, r, http.Header{"X-Test-Host": {r.Host}})
   577  			if err != nil {
   578  				t.Fatal(err)
   579  			}
   580  			c.Close()
   581  		} else {
   582  			w.Header().Set("X-Test-Host", r.Host)
   583  		}
   584  	})
   585  
   586  	server := httptest.NewServer(handler)
   587  	defer server.Close()
   588  
   589  	tlsServer := httptest.NewTLSServer(handler)
   590  	defer tlsServer.Close()
   591  
   592  	addrs := map[*httptest.Server]string{server: server.Listener.Addr().String(), tlsServer: tlsServer.Listener.Addr().String()}
   593  	wsProtos := map[*httptest.Server]string{server: "ws://", tlsServer: "wss://"}
   594  	httpProtos := map[*httptest.Server]string{server: "http://", tlsServer: "https://"}
   595  
   596  	// Avoid log noise from net/http server by logging to testing.T
   597  	server.Config.ErrorLog = log.New(testLogWriter{t}, "", 0)
   598  	tlsServer.Config.ErrorLog = server.Config.ErrorLog
   599  
   600  	cas := rootCAs(t, tlsServer)
   601  
   602  	tests := []struct {
   603  		fail               bool             // true if dial / get should fail
   604  		server             *httptest.Server // server to use
   605  		url                string           // host for request URI
   606  		header             string           // optional request host header
   607  		tls                string           // optional host for tls ServerName
   608  		wantAddr           string           // expected host for dial
   609  		wantHeader         string           // expected request header on server
   610  		insecureSkipVerify bool
   611  	}{
   612  		{
   613  			server:     server,
   614  			url:        addrs[server],
   615  			wantAddr:   addrs[server],
   616  			wantHeader: addrs[server],
   617  		},
   618  		{
   619  			server:     tlsServer,
   620  			url:        addrs[tlsServer],
   621  			wantAddr:   addrs[tlsServer],
   622  			wantHeader: addrs[tlsServer],
   623  		},
   624  
   625  		{
   626  			server:     server,
   627  			url:        addrs[server],
   628  			header:     "badhost.com",
   629  			wantAddr:   addrs[server],
   630  			wantHeader: "badhost.com",
   631  		},
   632  		{
   633  			server:     tlsServer,
   634  			url:        addrs[tlsServer],
   635  			header:     "badhost.com",
   636  			wantAddr:   addrs[tlsServer],
   637  			wantHeader: "badhost.com",
   638  		},
   639  
   640  		{
   641  			server:     server,
   642  			url:        "example.com",
   643  			header:     "badhost.com",
   644  			wantAddr:   "example.com:80",
   645  			wantHeader: "badhost.com",
   646  		},
   647  		{
   648  			server:     tlsServer,
   649  			url:        "example.com",
   650  			header:     "badhost.com",
   651  			wantAddr:   "example.com:443",
   652  			wantHeader: "badhost.com",
   653  		},
   654  
   655  		{
   656  			server:     server,
   657  			url:        "badhost.com",
   658  			header:     "example.com",
   659  			wantAddr:   "badhost.com:80",
   660  			wantHeader: "example.com",
   661  		},
   662  		{
   663  			fail:     true,
   664  			server:   tlsServer,
   665  			url:      "badhost.com",
   666  			header:   "example.com",
   667  			wantAddr: "badhost.com:443",
   668  		},
   669  		{
   670  			server:             tlsServer,
   671  			url:                "badhost.com",
   672  			insecureSkipVerify: true,
   673  			wantAddr:           "badhost.com:443",
   674  			wantHeader:         "badhost.com",
   675  		},
   676  		{
   677  			server:     tlsServer,
   678  			url:        "badhost.com",
   679  			tls:        "example.com",
   680  			wantAddr:   "badhost.com:443",
   681  			wantHeader: "badhost.com",
   682  		},
   683  	}
   684  
   685  	for i, tt := range tests {
   686  
   687  		tls := &tls.Config{
   688  			RootCAs:            cas,
   689  			ServerName:         tt.tls,
   690  			InsecureSkipVerify: tt.insecureSkipVerify,
   691  		}
   692  
   693  		var gotAddr string
   694  		dialer := Dialer{
   695  			NetDial: func(network, addr string) (net.Conn, error) {
   696  				gotAddr = addr
   697  				return net.Dial(network, addrs[tt.server])
   698  			},
   699  			TLSClientConfig: tls,
   700  		}
   701  
   702  		// Test websocket dial
   703  
   704  		h := http.Header{}
   705  		if tt.header != "" {
   706  			h.Set("Host", tt.header)
   707  		}
   708  		c, resp, err := dialer.Dial(wsProtos[tt.server]+tt.url+"/", h)
   709  		if err == nil {
   710  			c.Close()
   711  		}
   712  
   713  		check := func(protos map[*httptest.Server]string) {
   714  			name := fmt.Sprintf("%d: %s%s/ header[Host]=%q, tls.ServerName=%q", i+1, protos[tt.server], tt.url, tt.header, tt.tls)
   715  			if gotAddr != tt.wantAddr {
   716  				t.Errorf("%s: got addr %s, want %s", name, gotAddr, tt.wantAddr)
   717  			}
   718  			switch {
   719  			case tt.fail && err == nil:
   720  				t.Errorf("%s: unexpected success", name)
   721  			case !tt.fail && err != nil:
   722  				t.Errorf("%s: unexpected error %v", name, err)
   723  			case !tt.fail && err == nil:
   724  				if gotHost := resp.Header.Get("X-Test-Host"); gotHost != tt.wantHeader {
   725  					t.Errorf("%s: got host %s, want %s", name, gotHost, tt.wantHeader)
   726  				}
   727  			}
   728  		}
   729  
   730  		check(wsProtos)
   731  
   732  		// Confirm that net/http has same result
   733  
   734  		transport := &http.Transport{
   735  			Dial:            dialer.NetDial,
   736  			TLSClientConfig: dialer.TLSClientConfig,
   737  		}
   738  		req, _ := http.NewRequest("GET", httpProtos[tt.server]+tt.url+"/", nil)
   739  		if tt.header != "" {
   740  			req.Host = tt.header
   741  		}
   742  		client := &http.Client{Transport: transport}
   743  		resp, err = client.Do(req)
   744  		if err == nil {
   745  			resp.Body.Close()
   746  		}
   747  		transport.CloseIdleConnections()
   748  		check(httpProtos)
   749  	}
   750  }
   751  
   752  func TestDialCompression(t *testing.T) {
   753  	s := newServer(t)
   754  	defer s.Close()
   755  
   756  	dialer := cstDialer
   757  	dialer.EnableCompression = true
   758  	ws, _, err := dialer.Dial(s.URL, nil)
   759  	if err != nil {
   760  		t.Fatalf("Dial: %v", err)
   761  	}
   762  	defer ws.Close()
   763  	sendRecv(t, ws)
   764  }
   765  
   766  func TestSocksProxyDial(t *testing.T) {
   767  	s := newServer(t)
   768  	defer s.Close()
   769  
   770  	proxyListener, err := net.Listen("tcp", "127.0.0.1:0")
   771  	if err != nil {
   772  		t.Fatalf("listen failed: %v", err)
   773  	}
   774  	defer proxyListener.Close()
   775  	go func() {
   776  		c1, err := proxyListener.Accept()
   777  		if err != nil {
   778  			t.Errorf("proxy accept failed: %v", err)
   779  			return
   780  		}
   781  		defer c1.Close()
   782  
   783  		c1.SetDeadline(time.Now().Add(30 * time.Second))
   784  
   785  		buf := make([]byte, 32)
   786  		if _, err := io.ReadFull(c1, buf[:3]); err != nil {
   787  			t.Errorf("read failed: %v", err)
   788  			return
   789  		}
   790  		if want := []byte{5, 1, 0}; !bytes.Equal(want, buf[:len(want)]) {
   791  			t.Errorf("read %x, want %x", buf[:len(want)], want)
   792  		}
   793  		if _, err := c1.Write([]byte{5, 0}); err != nil {
   794  			t.Errorf("write failed: %v", err)
   795  			return
   796  		}
   797  		if _, err := io.ReadFull(c1, buf[:10]); err != nil {
   798  			t.Errorf("read failed: %v", err)
   799  			return
   800  		}
   801  		if want := []byte{5, 1, 0, 1}; !bytes.Equal(want, buf[:len(want)]) {
   802  			t.Errorf("read %x, want %x", buf[:len(want)], want)
   803  			return
   804  		}
   805  		buf[1] = 0
   806  		if _, err := c1.Write(buf[:10]); err != nil {
   807  			t.Errorf("write failed: %v", err)
   808  			return
   809  		}
   810  
   811  		ip := net.IP(buf[4:8])
   812  		port := binary.BigEndian.Uint16(buf[8:10])
   813  
   814  		c2, err := net.DialTCP("tcp", nil, &net.TCPAddr{IP: ip, Port: int(port)})
   815  		if err != nil {
   816  			t.Errorf("dial failed; %v", err)
   817  			return
   818  		}
   819  		defer c2.Close()
   820  		done := make(chan struct{})
   821  		go func() {
   822  			io.Copy(c1, c2)
   823  			close(done)
   824  		}()
   825  		io.Copy(c2, c1)
   826  		<-done
   827  	}()
   828  
   829  	purl, err := url.Parse("socks5://" + proxyListener.Addr().String())
   830  	if err != nil {
   831  		t.Fatalf("parse failed: %v", err)
   832  	}
   833  
   834  	cstDialer := cstDialer // make local copy for modification on next line.
   835  	cstDialer.Proxy = http.ProxyURL(purl)
   836  
   837  	ws, _, err := cstDialer.Dial(s.URL, nil)
   838  	if err != nil {
   839  		t.Fatalf("Dial: %v", err)
   840  	}
   841  	defer ws.Close()
   842  	sendRecv(t, ws)
   843  }
   844  
   845  func TestTracingDialWithContext(t *testing.T) {
   846  
   847  	var headersWrote, requestWrote, getConn, gotConn, connectDone, gotFirstResponseByte bool
   848  	trace := &httptrace.ClientTrace{
   849  		WroteHeaders: func() {
   850  			headersWrote = true
   851  		},
   852  		WroteRequest: func(httptrace.WroteRequestInfo) {
   853  			requestWrote = true
   854  		},
   855  		GetConn: func(hostPort string) {
   856  			getConn = true
   857  		},
   858  		GotConn: func(info httptrace.GotConnInfo) {
   859  			gotConn = true
   860  		},
   861  		ConnectDone: func(network, addr string, err error) {
   862  			connectDone = true
   863  		},
   864  		GotFirstResponseByte: func() {
   865  			gotFirstResponseByte = true
   866  		},
   867  	}
   868  	ctx := httptrace.WithClientTrace(context.Background(), trace)
   869  
   870  	s := newTLSServer(t)
   871  	defer s.Close()
   872  
   873  	d := cstDialer
   874  	d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
   875  
   876  	ws, _, err := d.DialContext(ctx, s.URL, nil)
   877  	if err != nil {
   878  		t.Fatalf("Dial: %v", err)
   879  	}
   880  
   881  	if !headersWrote {
   882  		t.Fatal("Headers was not written")
   883  	}
   884  	if !requestWrote {
   885  		t.Fatal("Request was not written")
   886  	}
   887  	if !getConn {
   888  		t.Fatal("getConn was not called")
   889  	}
   890  	if !gotConn {
   891  		t.Fatal("gotConn was not called")
   892  	}
   893  	if !connectDone {
   894  		t.Fatal("connectDone was not called")
   895  	}
   896  	if !gotFirstResponseByte {
   897  		t.Fatal("GotFirstResponseByte was not called")
   898  	}
   899  
   900  	defer ws.Close()
   901  	sendRecv(t, ws)
   902  }
   903  
   904  func TestEmptyTracingDialWithContext(t *testing.T) {
   905  
   906  	trace := &httptrace.ClientTrace{}
   907  	ctx := httptrace.WithClientTrace(context.Background(), trace)
   908  
   909  	s := newTLSServer(t)
   910  	defer s.Close()
   911  
   912  	d := cstDialer
   913  	d.TLSClientConfig = &tls.Config{RootCAs: rootCAs(t, s.Server)}
   914  
   915  	ws, _, err := d.DialContext(ctx, s.URL, nil)
   916  	if err != nil {
   917  		t.Fatalf("Dial: %v", err)
   918  	}
   919  
   920  	defer ws.Close()
   921  	sendRecv(t, ws)
   922  }