github.com/insionng/yougam@v0.0.0-20170714101924-2bc18d833463/libraries/gorilla/websocket/client.go (about)

     1  // Copyright 2013 The Gorilla WebSocket Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bufio"
     9  	"bytes"
    10  	"crypto/tls"
    11  	"encoding/base64"
    12  	"errors"
    13  	"io"
    14  	"io/ioutil"
    15  	"net"
    16  	"net/http"
    17  	"net/url"
    18  	"strings"
    19  	"time"
    20  )
    21  
    22  // ErrBadHandshake is returned when the server response to opening handshake is
    23  // invalid.
    24  var ErrBadHandshake = errors.New("websocket: bad handshake")
    25  
    26  // NewClient creates a new client connection using the given net connection.
    27  // The URL u specifies the host and request URI. Use requestHeader to specify
    28  // the origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies
    29  // (Cookie). Use the response.Header to get the selected subprotocol
    30  // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
    31  //
    32  // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
    33  // non-nil *http.Response so that callers can handle redirects, authentication,
    34  // etc.
    35  //
    36  // Deprecated: Use Dialer instead.
    37  func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) {
    38  	d := Dialer{
    39  		ReadBufferSize:  readBufSize,
    40  		WriteBufferSize: writeBufSize,
    41  		NetDial: func(net, addr string) (net.Conn, error) {
    42  			return netConn, nil
    43  		},
    44  	}
    45  	return d.Dial(u.String(), requestHeader)
    46  }
    47  
    48  // A Dialer contains options for connecting to WebSocket server.
    49  type Dialer struct {
    50  	// NetDial specifies the dial function for creating TCP connections. If
    51  	// NetDial is nil, net.Dial is used.
    52  	NetDial func(network, addr string) (net.Conn, error)
    53  
    54  	// Proxy specifies a function to return a proxy for a given
    55  	// Request. If the function returns a non-nil error, the
    56  	// request is aborted with the provided error.
    57  	// If Proxy is nil or returns a nil *URL, no proxy is used.
    58  	Proxy func(*http.Request) (*url.URL, error)
    59  
    60  	// TLSClientConfig specifies the TLS configuration to use with tls.Client.
    61  	// If nil, the default configuration is used.
    62  	TLSClientConfig *tls.Config
    63  
    64  	// HandshakeTimeout specifies the duration for the handshake to complete.
    65  	HandshakeTimeout time.Duration
    66  
    67  	// Input and output buffer sizes. If the buffer size is zero, then a
    68  	// default value of 4096 is used.
    69  	ReadBufferSize, WriteBufferSize int
    70  
    71  	// Subprotocols specifies the client's requested subprotocols.
    72  	Subprotocols []string
    73  }
    74  
    75  var errMalformedURL = errors.New("malformed ws or wss URL")
    76  
    77  // parseURL parses the URL.
    78  //
    79  // This function is a replacement for the standard library url.Parse function.
    80  // In Go 1.4 and earlier, url.Parse loses information from the path.
    81  func parseURL(s string) (*url.URL, error) {
    82  	// From the RFC:
    83  	//
    84  	// ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ]
    85  	// wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ]
    86  
    87  	var u url.URL
    88  	switch {
    89  	case strings.HasPrefix(s, "ws://"):
    90  		u.Scheme = "ws"
    91  		s = s[len("ws://"):]
    92  	case strings.HasPrefix(s, "wss://"):
    93  		u.Scheme = "wss"
    94  		s = s[len("wss://"):]
    95  	default:
    96  		return nil, errMalformedURL
    97  	}
    98  
    99  	if i := strings.Index(s, "?"); i >= 0 {
   100  		u.RawQuery = s[i+1:]
   101  		s = s[:i]
   102  	}
   103  
   104  	if i := strings.Index(s, "/"); i >= 0 {
   105  		u.Opaque = s[i:]
   106  		s = s[:i]
   107  	} else {
   108  		u.Opaque = "/"
   109  	}
   110  
   111  	u.Host = s
   112  
   113  	if strings.Contains(u.Host, "@") {
   114  		// Don't bother parsing user information because user information is
   115  		// not allowed in websocket URIs.
   116  		return nil, errMalformedURL
   117  	}
   118  
   119  	return &u, nil
   120  }
   121  
   122  func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) {
   123  	hostPort = u.Host
   124  	hostNoPort = u.Host
   125  	if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") {
   126  		hostNoPort = hostNoPort[:i]
   127  	} else {
   128  		switch u.Scheme {
   129  		case "wss":
   130  			hostPort += ":443"
   131  		case "https":
   132  			hostPort += ":443"
   133  		default:
   134  			hostPort += ":80"
   135  		}
   136  	}
   137  	return hostPort, hostNoPort
   138  }
   139  
   140  // DefaultDialer is a dialer with all fields set to the default zero values.
   141  var DefaultDialer = &Dialer{
   142  	Proxy: http.ProxyFromEnvironment,
   143  }
   144  
   145  // Dial creates a new client connection. Use requestHeader to specify the
   146  // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie).
   147  // Use the response.Header to get the selected subprotocol
   148  // (Sec-WebSocket-Protocol) and cookies (Set-Cookie).
   149  //
   150  // If the WebSocket handshake fails, ErrBadHandshake is returned along with a
   151  // non-nil *http.Response so that callers can handle redirects, authentication,
   152  // etcetera. The response body may not contain the entire response and does not
   153  // need to be closed by the application.
   154  func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) {
   155  
   156  	if d == nil {
   157  		d = &Dialer{
   158  			Proxy: http.ProxyFromEnvironment,
   159  		}
   160  	}
   161  
   162  	challengeKey, err := generateChallengeKey()
   163  	if err != nil {
   164  		return nil, nil, err
   165  	}
   166  
   167  	u, err := parseURL(urlStr)
   168  	if err != nil {
   169  		return nil, nil, err
   170  	}
   171  
   172  	switch u.Scheme {
   173  	case "ws":
   174  		u.Scheme = "http"
   175  	case "wss":
   176  		u.Scheme = "https"
   177  	default:
   178  		return nil, nil, errMalformedURL
   179  	}
   180  
   181  	if u.User != nil {
   182  		// User name and password are not allowed in websocket URIs.
   183  		return nil, nil, errMalformedURL
   184  	}
   185  
   186  	req := &http.Request{
   187  		Method:     "GET",
   188  		URL:        u,
   189  		Proto:      "HTTP/1.1",
   190  		ProtoMajor: 1,
   191  		ProtoMinor: 1,
   192  		Header:     make(http.Header),
   193  		Host:       u.Host,
   194  	}
   195  
   196  	// Set the request headers using the capitalization for names and values in
   197  	// RFC examples. Although the capitalization shouldn't matter, there are
   198  	// servers that depend on it. The Header.Set method is not used because the
   199  	// method canonicalizes the header names.
   200  	req.Header["Upgrade"] = []string{"websocket"}
   201  	req.Header["Connection"] = []string{"Upgrade"}
   202  	req.Header["Sec-WebSocket-Key"] = []string{challengeKey}
   203  	req.Header["Sec-WebSocket-Version"] = []string{"13"}
   204  	if len(d.Subprotocols) > 0 {
   205  		req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")}
   206  	}
   207  	for k, vs := range requestHeader {
   208  		switch {
   209  		case k == "Host":
   210  			if len(vs) > 0 {
   211  				req.Host = vs[0]
   212  			}
   213  		case k == "Upgrade" ||
   214  			k == "Connection" ||
   215  			k == "Sec-Websocket-Key" ||
   216  			k == "Sec-Websocket-Version" ||
   217  			(k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0):
   218  			return nil, nil, errors.New("websocket: duplicate header not allowed: " + k)
   219  		default:
   220  			req.Header[k] = vs
   221  		}
   222  	}
   223  
   224  	hostPort, hostNoPort := hostPortNoPort(u)
   225  
   226  	var proxyURL *url.URL
   227  	// Check wether the proxy method has been configured
   228  	if d.Proxy != nil {
   229  		proxyURL, err = d.Proxy(req)
   230  	}
   231  	if err != nil {
   232  		return nil, nil, err
   233  	}
   234  
   235  	var targetHostPort string
   236  	if proxyURL != nil {
   237  		targetHostPort, _ = hostPortNoPort(proxyURL)
   238  	} else {
   239  		targetHostPort = hostPort
   240  	}
   241  
   242  	var deadline time.Time
   243  	if d.HandshakeTimeout != 0 {
   244  		deadline = time.Now().Add(d.HandshakeTimeout)
   245  	}
   246  
   247  	netDial := d.NetDial
   248  	if netDial == nil {
   249  		netDialer := &net.Dialer{Deadline: deadline}
   250  		netDial = netDialer.Dial
   251  	}
   252  
   253  	netConn, err := netDial("tcp", targetHostPort)
   254  	if err != nil {
   255  		return nil, nil, err
   256  	}
   257  
   258  	defer func() {
   259  		if netConn != nil {
   260  			netConn.Close()
   261  		}
   262  	}()
   263  
   264  	if err := netConn.SetDeadline(deadline); err != nil {
   265  		return nil, nil, err
   266  	}
   267  
   268  	if proxyURL != nil {
   269  		connectHeader := make(http.Header)
   270  		if user := proxyURL.User; user != nil {
   271  			proxyUser := user.Username()
   272  			if proxyPassword, passwordSet := user.Password(); passwordSet {
   273  				credential := base64.StdEncoding.EncodeToString([]byte(proxyUser + ":" + proxyPassword))
   274  				connectHeader.Set("Proxy-Authorization", "Basic "+credential)
   275  			}
   276  		}
   277  		connectReq := &http.Request{
   278  			Method: "CONNECT",
   279  			URL:    &url.URL{Opaque: hostPort},
   280  			Host:   hostPort,
   281  			Header: connectHeader,
   282  		}
   283  
   284  		connectReq.Write(netConn)
   285  
   286  		// Read response.
   287  		// Okay to use and discard buffered reader here, because
   288  		// TLS server will not speak until spoken to.
   289  		br := bufio.NewReader(netConn)
   290  		resp, err := http.ReadResponse(br, connectReq)
   291  		if err != nil {
   292  			return nil, nil, err
   293  		}
   294  		if resp.StatusCode != 200 {
   295  			f := strings.SplitN(resp.Status, " ", 2)
   296  			return nil, nil, errors.New(f[1])
   297  		}
   298  	}
   299  
   300  	if u.Scheme == "https" {
   301  		cfg := d.TLSClientConfig
   302  		if cfg == nil {
   303  			cfg = &tls.Config{ServerName: hostNoPort}
   304  		} else if cfg.ServerName == "" {
   305  			shallowCopy := *cfg
   306  			cfg = &shallowCopy
   307  			cfg.ServerName = hostNoPort
   308  		}
   309  		tlsConn := tls.Client(netConn, cfg)
   310  		netConn = tlsConn
   311  		if err := tlsConn.Handshake(); err != nil {
   312  			return nil, nil, err
   313  		}
   314  		if !cfg.InsecureSkipVerify {
   315  			if err := tlsConn.VerifyHostname(cfg.ServerName); err != nil {
   316  				return nil, nil, err
   317  			}
   318  		}
   319  	}
   320  
   321  	conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize)
   322  
   323  	if err := req.Write(netConn); err != nil {
   324  		return nil, nil, err
   325  	}
   326  
   327  	resp, err := http.ReadResponse(conn.br, req)
   328  	if err != nil {
   329  		return nil, nil, err
   330  	}
   331  	if resp.StatusCode != 101 ||
   332  		!strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") ||
   333  		!strings.EqualFold(resp.Header.Get("Connection"), "upgrade") ||
   334  		resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) {
   335  		// Before closing the network connection on return from this
   336  		// function, slurp up some of the response to aid application
   337  		// debugging.
   338  		buf := make([]byte, 1024)
   339  		n, _ := io.ReadFull(resp.Body, buf)
   340  		resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n]))
   341  		return nil, resp, ErrBadHandshake
   342  	}
   343  
   344  	resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{}))
   345  	conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol")
   346  
   347  	netConn.SetDeadline(time.Time{})
   348  	netConn = nil // to avoid close in defer.
   349  	return conn, resp, nil
   350  }