github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/accept.go (about)

     1  // +build !js
     2  
     3  package websocket
     4  
     5  import (
     6  	"bytes"
     7  	"crypto/sha1"
     8  	"encoding/base64"
     9  	"errors"
    10  	"fmt"
    11  	"io"
    12  	"log"
    13  	"net/http"
    14  	"net/textproto"
    15  	"net/url"
    16  	"path/filepath"
    17  	"strings"
    18  
    19  	"nhooyr.io/websocket/internal/errd"
    20  )
    21  
// AcceptOptions represents Accept's options.
//
// A nil *AcceptOptions passed to Accept behaves like the zero value.
type AcceptOptions struct {
	// Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client.
	// They are tried in the listed order and the first one the client also offered
	// (compared case-insensitively) is selected.
	// The empty subprotocol will always be negotiated as per RFC 6455. If you would like to
	// reject it, close the connection when c.Subprotocol() == "".
	Subprotocols []string

	// InsecureSkipVerify is used to disable Accept's origin verification behaviour.
	// When true, the Origin header is not checked at all.
	//
	// You probably want to use OriginPatterns instead.
	InsecureSkipVerify bool

	// OriginPatterns lists the host patterns for authorized origins.
	// The request host is always authorized, as are requests that carry no Origin header.
	// Use this to enable cross origin WebSockets.
	//
	// For example, JavaScript running on example.com wants to access a WebSocket server at chat.example.com.
	// In such a case, example.com is the origin and chat.example.com is the request host.
	// One would set this field to []string{"example.com"} to authorize example.com to connect.
	//
	// Each pattern is matched case insensitively against the request origin host
	// with filepath.Match.
	// See https://golang.org/pkg/path/filepath/#Match
	// If a malformed pattern is evaluated, the pattern error is logged and the
	// request is rejected with 403 Forbidden.
	//
	// Please ensure you understand the ramifications of enabling this.
	// If used incorrectly your WebSocket server will be open to CSRF attacks.
	//
	// Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead
	// to bring attention to the danger of such a setting.
	OriginPatterns []string

	// CompressionMode controls the compression mode.
	// Defaults to CompressionNoContextTakeover.
	// CompressionDisabled skips extension negotiation entirely.
	//
	// See docs on CompressionMode for details.
	CompressionMode CompressionMode

	// CompressionThreshold controls the minimum size of a message before compression is applied.
	// It is handed to the accepted connection as its flate threshold.
	//
	// Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes
	// for CompressionContextTakeover.
	CompressionThreshold int
}
    65  
    66  // Accept accepts a WebSocket handshake from a client and upgrades the
    67  // the connection to a WebSocket.
    68  //
    69  // Accept will not allow cross origin requests by default.
    70  // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests.
    71  //
    72  // Accept will write a response to w on all errors.
    73  func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) {
    74  	return accept(w, r, opts)
    75  }
    76  
    77  func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) {
    78  	defer errd.Wrap(&err, "failed to accept WebSocket connection")
    79  
    80  	if opts == nil {
    81  		opts = &AcceptOptions{}
    82  	}
    83  	opts = &*opts
    84  
    85  	errCode, err := verifyClientRequest(w, r)
    86  	if err != nil {
    87  		http.Error(w, err.Error(), errCode)
    88  		return nil, err
    89  	}
    90  
    91  	if !opts.InsecureSkipVerify {
    92  		err = authenticateOrigin(r, opts.OriginPatterns)
    93  		if err != nil {
    94  			if errors.Is(err, filepath.ErrBadPattern) {
    95  				log.Printf("websocket: %v", err)
    96  				err = errors.New(http.StatusText(http.StatusForbidden))
    97  			}
    98  			http.Error(w, err.Error(), http.StatusForbidden)
    99  			return nil, err
   100  		}
   101  	}
   102  
   103  	hj, ok := w.(http.Hijacker)
   104  	if !ok {
   105  		err = errors.New("http.ResponseWriter does not implement http.Hijacker")
   106  		http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented)
   107  		return nil, err
   108  	}
   109  
   110  	w.Header().Set("Upgrade", "websocket")
   111  	w.Header().Set("Connection", "Upgrade")
   112  
   113  	key := r.Header.Get("Sec-WebSocket-Key")
   114  	w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key))
   115  
   116  	subproto := selectSubprotocol(r, opts.Subprotocols)
   117  	if subproto != "" {
   118  		w.Header().Set("Sec-WebSocket-Protocol", subproto)
   119  	}
   120  
   121  	copts, err := acceptCompression(r, w, opts.CompressionMode)
   122  	if err != nil {
   123  		return nil, err
   124  	}
   125  
   126  	w.WriteHeader(http.StatusSwitchingProtocols)
   127  	// See https://github.com/nhooyr/websocket/issues/166
   128  	if ginWriter, ok := w.(interface {
   129  		WriteHeaderNow()
   130  	}); ok {
   131  		ginWriter.WriteHeaderNow()
   132  	}
   133  
   134  	netConn, brw, err := hj.Hijack()
   135  	if err != nil {
   136  		err = fmt.Errorf("failed to hijack connection: %w", err)
   137  		http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
   138  		return nil, err
   139  	}
   140  
   141  	// https://github.com/golang/go/issues/32314
   142  	b, _ := brw.Reader.Peek(brw.Reader.Buffered())
   143  	brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn))
   144  
   145  	return newConn(connConfig{
   146  		subprotocol:    w.Header().Get("Sec-WebSocket-Protocol"),
   147  		rwc:            netConn,
   148  		client:         false,
   149  		copts:          copts,
   150  		flateThreshold: opts.CompressionThreshold,
   151  
   152  		br: brw.Reader,
   153  		bw: brw.Writer,
   154  	}), nil
   155  }
   156  
   157  func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) {
   158  	if !r.ProtoAtLeast(1, 1) {
   159  		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto)
   160  	}
   161  
   162  	if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") {
   163  		w.Header().Set("Connection", "Upgrade")
   164  		w.Header().Set("Upgrade", "websocket")
   165  		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection"))
   166  	}
   167  
   168  	if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") {
   169  		w.Header().Set("Connection", "Upgrade")
   170  		w.Header().Set("Upgrade", "websocket")
   171  		return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade"))
   172  	}
   173  
   174  	if r.Method != "GET" {
   175  		return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method)
   176  	}
   177  
   178  	if r.Header.Get("Sec-WebSocket-Version") != "13" {
   179  		w.Header().Set("Sec-WebSocket-Version", "13")
   180  		return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version"))
   181  	}
   182  
   183  	if r.Header.Get("Sec-WebSocket-Key") == "" {
   184  		return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key")
   185  	}
   186  
   187  	return 0, nil
   188  }
   189  
   190  func authenticateOrigin(r *http.Request, originHosts []string) error {
   191  	origin := r.Header.Get("Origin")
   192  	if origin == "" {
   193  		return nil
   194  	}
   195  
   196  	u, err := url.Parse(origin)
   197  	if err != nil {
   198  		return fmt.Errorf("failed to parse Origin header %q: %w", origin, err)
   199  	}
   200  
   201  	if strings.EqualFold(r.Host, u.Host) {
   202  		return nil
   203  	}
   204  
   205  	for _, hostPattern := range originHosts {
   206  		matched, err := match(hostPattern, u.Host)
   207  		if err != nil {
   208  			return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err)
   209  		}
   210  		if matched {
   211  			return nil
   212  		}
   213  	}
   214  	return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host)
   215  }
   216  
   217  func match(pattern, s string) (bool, error) {
   218  	return filepath.Match(strings.ToLower(pattern), strings.ToLower(s))
   219  }
   220  
   221  func selectSubprotocol(r *http.Request, subprotocols []string) string {
   222  	cps := headerTokens(r.Header, "Sec-WebSocket-Protocol")
   223  	for _, sp := range subprotocols {
   224  		for _, cp := range cps {
   225  			if strings.EqualFold(sp, cp) {
   226  				return cp
   227  			}
   228  		}
   229  	}
   230  	return ""
   231  }
   232  
   233  func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) {
   234  	if mode == CompressionDisabled {
   235  		return nil, nil
   236  	}
   237  
   238  	for _, ext := range websocketExtensions(r.Header) {
   239  		switch ext.name {
   240  		case "permessage-deflate":
   241  			return acceptDeflate(w, ext, mode)
   242  			// Disabled for now, see https://github.com/nhooyr/websocket/issues/218
   243  			// case "x-webkit-deflate-frame":
   244  			// 	return acceptWebkitDeflate(w, ext, mode)
   245  		}
   246  	}
   247  	return nil, nil
   248  }
   249  
   250  func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
   251  	copts := mode.opts()
   252  
   253  	for _, p := range ext.params {
   254  		switch p {
   255  		case "client_no_context_takeover":
   256  			copts.clientNoContextTakeover = true
   257  			continue
   258  		case "server_no_context_takeover":
   259  			copts.serverNoContextTakeover = true
   260  			continue
   261  		}
   262  
   263  		if strings.HasPrefix(p, "client_max_window_bits") {
   264  			// We cannot adjust the read sliding window so cannot make use of this.
   265  			continue
   266  		}
   267  
   268  		err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p)
   269  		http.Error(w, err.Error(), http.StatusBadRequest)
   270  		return nil, err
   271  	}
   272  
   273  	copts.setHeader(w.Header())
   274  
   275  	return copts, nil
   276  }
   277  
   278  func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) {
   279  	copts := mode.opts()
   280  	// The peer must explicitly request it.
   281  	copts.serverNoContextTakeover = false
   282  
   283  	for _, p := range ext.params {
   284  		if p == "no_context_takeover" {
   285  			copts.serverNoContextTakeover = true
   286  			continue
   287  		}
   288  
   289  		// We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead
   290  		// of ignoring it as the draft spec is unclear. It says the server can ignore it
   291  		// but the server has no way of signalling to the client it was ignored as the parameters
   292  		// are set one way.
   293  		// Thus us ignoring it would make the client think we understood it which would cause issues.
   294  		// See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1
   295  		//
   296  		// Either way, we're only implementing this for webkit which never sends the max_window_bits
   297  		// parameter so we don't need to worry about it.
   298  		err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p)
   299  		http.Error(w, err.Error(), http.StatusBadRequest)
   300  		return nil, err
   301  	}
   302  
   303  	s := "x-webkit-deflate-frame"
   304  	if copts.clientNoContextTakeover {
   305  		s += "; no_context_takeover"
   306  	}
   307  	w.Header().Set("Sec-WebSocket-Extensions", s)
   308  
   309  	return copts, nil
   310  }
   311  
   312  func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool {
   313  	for _, t := range headerTokens(h, key) {
   314  		if strings.EqualFold(t, token) {
   315  			return true
   316  		}
   317  	}
   318  	return false
   319  }
   320  
// websocketExtension is one extension offer parsed from the
// Sec-WebSocket-Extensions header by websocketExtensions.
type websocketExtension struct {
	name   string   // extension token, e.g. "permessage-deflate"
	params []string // whitespace-trimmed raw parameters after the name; "key=value" is not split
}
   325  
   326  func websocketExtensions(h http.Header) []websocketExtension {
   327  	var exts []websocketExtension
   328  	extStrs := headerTokens(h, "Sec-WebSocket-Extensions")
   329  	for _, extStr := range extStrs {
   330  		if extStr == "" {
   331  			continue
   332  		}
   333  
   334  		vals := strings.Split(extStr, ";")
   335  		for i := range vals {
   336  			vals[i] = strings.TrimSpace(vals[i])
   337  		}
   338  
   339  		e := websocketExtension{
   340  			name:   vals[0],
   341  			params: vals[1:],
   342  		}
   343  
   344  		exts = append(exts, e)
   345  	}
   346  	return exts
   347  }
   348  
   349  func headerTokens(h http.Header, key string) []string {
   350  	key = textproto.CanonicalMIMEHeaderKey(key)
   351  	var tokens []string
   352  	for _, v := range h[key] {
   353  		v = strings.TrimSpace(v)
   354  		for _, t := range strings.Split(v, ",") {
   355  			t = strings.TrimSpace(t)
   356  			tokens = append(tokens, t)
   357  		}
   358  	}
   359  	return tokens
   360  }
   361  
   362  var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
   363  
   364  func secWebSocketAccept(secWebSocketKey string) string {
   365  	h := sha1.New()
   366  	h.Write([]byte(secWebSocketKey))
   367  	h.Write(keyGUID)
   368  
   369  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
   370  }