github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/websocket/x/client.go (about)

     1  // Copyright 2009 The Go Authors. All rights reserved.
     2  // Use of this source code is governed by a BSD-style
     3  // license that can be found in the LICENSE file.
     4  
     5  package websocket
     6  
     7  import (
     8  	"bufio"
     9  	"crypto/rand"
    10  	"encoding/base64"
    11  	"fmt"
    12  	"io"
    13  	"net"
    14  	"net/http"
    15  	"strings"
    16  	_ "unsafe"
    17  
    18  	"github.com/Asutorufa/yuhaiin/pkg/net/netapi"
    19  )
    20  
// Config is a WebSocket configuration.
//
// It describes the target of a client opening handshake together with the
// optional Origin and subprotocol headers sent with the upgrade request.
type Config struct {
	// Host is the authority the request URL is built from ("http://"+Host+…),
	// e.g. "example.com" or "example.com:8080"; net/http also derives the
	// request's Host header from it.
	Host string
	// Path is the request path, e.g. "/ws". It should begin with "/".
	Path string

	// A Websocket client origin. When non-empty it is sent as the Origin
	// header of the upgrade request.
	OriginUrl string // eg: http://example.com/from/ws

	// WebSocket subprotocols offered to the server, each sent as its own
	// Sec-WebSocket-Protocol header line. If the server selects a
	// subprotocol, it must be one of these (compared with strings.EqualFold).
	Protocol []string
}
    32  
    33  // NewClient creates a new WebSocket client connection over rwc.
    34  func (config *Config) NewClient(SecWebSocketKey string, rwc net.Conn, request func(*http.Request) error, handshake func(*http.Response) error) (ws *Conn, err error) {
    35  	rwc, err = config.hybiClientHandshake(SecWebSocketKey, rwc, request, handshake)
    36  	if err != nil {
    37  		return
    38  	}
    39  	ws = newConn(rwc, false)
    40  	return
    41  }
    42  
// NewBufioReader returns a *bufio.Reader reading from r, obtained by linking
// to the unexported net/http.newBufioReader, which hands out pooled readers.
// Return it with PutBufioReader when finished and do not use it afterwards.
//
// NOTE(review): these declarations "pull" unexported net/http symbols via
// //go:linkname. Go 1.23+ restricts linkname references into the standard
// library, so confirm these symbols still link with the toolchain in use.
//
//go:linkname NewBufioReader net/http.newBufioReader
func NewBufioReader(r io.Reader) *bufio.Reader

// PutBufioReader hands br back to net/http (linked to the unexported
// net/http.putBufioReader) for reuse. br must not be used after this call.
//
//go:linkname PutBufioReader net/http.putBufioReader
func PutBufioReader(br *bufio.Reader)

// newBufioWriterSize returns a *bufio.Writer with the given buffer size that
// writes to w, linked to the unexported net/http.newBufioWriterSize.
// NOTE(review): net/http pools writers only for certain sizes — confirm which
// sizes before relying on reuse.
//
//go:linkname newBufioWriterSize net/http.newBufioWriterSize
func newBufioWriterSize(w io.Writer, size int) *bufio.Writer

// putBufioWriter hands the writer br back to net/http (linked to the
// unexported net/http.putBufioWriter). It must not be used afterwards.
//
//go:linkname putBufioWriter net/http.putBufioWriter
func putBufioWriter(br *bufio.Writer)
    54  
    55  // Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17
    56  func (config *Config) hybiClientHandshake(SecWebSocketKey string, conn net.Conn, request func(*http.Request) error, handshake func(*http.Response) error) (net.Conn, error) {
    57  	var nonce string
    58  	if SecWebSocketKey != "" {
    59  		nonce = SecWebSocketKey
    60  	} else {
    61  		nonce = generateNonce()
    62  	}
    63  
    64  	req, err := http.NewRequest(http.MethodGet, "http://"+config.Host+config.Path, http.NoBody)
    65  	if err != nil {
    66  		return nil, err
    67  	}
    68  	req.Header.Set("Upgrade", "websocket")
    69  	req.Header.Set("Connection", "Upgrade")
    70  	if config.OriginUrl != "" {
    71  		req.Header.Set("Origin", config.OriginUrl)
    72  	}
    73  	req.Header.Set("Sec-WebSocket-Key", nonce)
    74  	req.Header.Set("Sec-WebSocket-Version", SupportedProtocolVersion)
    75  	for _, p := range config.Protocol {
    76  		req.Header.Add("Sec-WebSocket-Protocol", p)
    77  	}
    78  	if request != nil {
    79  		if err := request(req); err != nil {
    80  			return nil, err
    81  		}
    82  	}
    83  	if err := req.Write(conn); err != nil {
    84  		return nil, err
    85  	}
    86  
    87  	reader := NewBufioReader(conn)
    88  	defer PutBufioReader(reader)
    89  
    90  	resp, err := http.ReadResponse(reader, req)
    91  	if err != nil {
    92  		return nil, err
    93  	}
    94  	if resp.StatusCode != http.StatusSwitchingProtocols {
    95  		return nil, ErrBadStatus
    96  	}
    97  	if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || strings.ToLower(resp.Header.Get("Connection")) != "upgrade" {
    98  		return nil, ErrBadUpgrade
    99  	}
   100  
   101  	if resp.Header.Get("Sec-WebSocket-Accept") != getNonceAccept(nonce) {
   102  		return nil, ErrChallengeResponse
   103  	}
   104  
   105  	if resp.Header.Get("Sec-WebSocket-Extensions") != "" {
   106  		return nil, ErrUnsupportedExtensions
   107  	}
   108  
   109  	if err = verifySubprotocol(config.Protocol, resp); err != nil {
   110  		return nil, err
   111  	}
   112  
   113  	if handshake != nil {
   114  		if err = handshake(resp); err != nil {
   115  			return nil, err
   116  		}
   117  	}
   118  
   119  	return netapi.MergeBufioReaderConn(conn, reader)
   120  }
   121  
   122  func verifySubprotocol(subprotos []string, resp *http.Response) error {
   123  	proto := resp.Header.Get("Sec-WebSocket-Protocol")
   124  	if proto == "" {
   125  		return nil
   126  	}
   127  
   128  	for _, sp2 := range subprotos {
   129  		if strings.EqualFold(sp2, proto) {
   130  			return nil
   131  		}
   132  	}
   133  
   134  	return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto)
   135  }
   136  
   137  // generateNonce generates a nonce consisting of a randomly selected 16-byte
   138  // value that has been base64-encoded.
   139  func generateNonce() string {
   140  	key := make([]byte, 16)
   141  	if _, err := io.ReadFull(rand.Reader, key); err != nil {
   142  		panic(err)
   143  	}
   144  	return base64.StdEncoding.EncodeToString(key)
   145  }