github.com/Asutorufa/yuhaiin@v0.3.6-0.20240502055049-7984da7023a0/pkg/net/proxy/websocket/x/client.go (about) 1 // Copyright 2009 The Go Authors. All rights reserved. 2 // Use of this source code is governed by a BSD-style 3 // license that can be found in the LICENSE file. 4 5 package websocket 6 7 import ( 8 "bufio" 9 "crypto/rand" 10 "encoding/base64" 11 "fmt" 12 "io" 13 "net" 14 "net/http" 15 "strings" 16 _ "unsafe" 17 18 "github.com/Asutorufa/yuhaiin/pkg/net/netapi" 19 ) 20 21 // Config is a WebSocket configuration 22 type Config struct { 23 Host string 24 Path string 25 26 // A Websocket client origin. 27 OriginUrl string // eg: http://example.com/from/ws 28 29 // WebSocket subprotocols. 30 Protocol []string 31 } 32 33 // NewClient creates a new WebSocket client connection over rwc. 34 func (config *Config) NewClient(SecWebSocketKey string, rwc net.Conn, request func(*http.Request) error, handshake func(*http.Response) error) (ws *Conn, err error) { 35 rwc, err = config.hybiClientHandshake(SecWebSocketKey, rwc, request, handshake) 36 if err != nil { 37 return 38 } 39 ws = newConn(rwc, false) 40 return 41 } 42 43 //go:linkname NewBufioReader net/http.newBufioReader 44 func NewBufioReader(r io.Reader) *bufio.Reader 45 46 //go:linkname PutBufioReader net/http.putBufioReader 47 func PutBufioReader(br *bufio.Reader) 48 49 //go:linkname newBufioWriterSize net/http.newBufioWriterSize 50 func newBufioWriterSize(w io.Writer, size int) *bufio.Writer 51 52 //go:linkname putBufioWriter net/http.putBufioWriter 53 func putBufioWriter(br *bufio.Writer) 54 55 // Client handshake described in draft-ietf-hybi-thewebsocket-protocol-17 56 func (config *Config) hybiClientHandshake(SecWebSocketKey string, conn net.Conn, request func(*http.Request) error, handshake func(*http.Response) error) (net.Conn, error) { 57 var nonce string 58 if SecWebSocketKey != "" { 59 nonce = SecWebSocketKey 60 } else { 61 nonce = generateNonce() 62 } 63 64 req, err := http.NewRequest(http.MethodGet, "http://"+config.Host+config.Path, http.NoBody) 65 if err != nil { 66 return nil, err 67 } 68 req.Header.Set("Upgrade", "websocket") 69 req.Header.Set("Connection", "Upgrade") 70 if config.OriginUrl != "" { 71 req.Header.Set("Origin", config.OriginUrl) 72 } 73 req.Header.Set("Sec-WebSocket-Key", nonce) 74 req.Header.Set("Sec-WebSocket-Version", SupportedProtocolVersion) 75 for _, p := range config.Protocol { 76 req.Header.Add("Sec-WebSocket-Protocol", p) 77 } 78 if request != nil { 79 if err := request(req); err != nil { 80 return nil, err 81 } 82 } 83 if err := req.Write(conn); err != nil { 84 return nil, err 85 } 86 87 reader := NewBufioReader(conn) 88 defer PutBufioReader(reader) 89 90 resp, err := http.ReadResponse(reader, req) 91 if err != nil { 92 return nil, err 93 } 94 if resp.StatusCode != http.StatusSwitchingProtocols { 95 return nil, ErrBadStatus 96 } 97 if strings.ToLower(resp.Header.Get("Upgrade")) != "websocket" || strings.ToLower(resp.Header.Get("Connection")) != "upgrade" { 98 return nil, ErrBadUpgrade 99 } 100 101 if resp.Header.Get("Sec-WebSocket-Accept") != getNonceAccept(nonce) { 102 return nil, ErrChallengeResponse 103 } 104 105 if resp.Header.Get("Sec-WebSocket-Extensions") != "" { 106 return nil, ErrUnsupportedExtensions 107 } 108 109 if err = verifySubprotocol(config.Protocol, resp); err != nil { 110 return nil, err 111 } 112 113 if handshake != nil { 114 if err = handshake(resp); err != nil { 115 return nil, err 116 } 117 } 118 119 return netapi.MergeBufioReaderConn(conn, reader) 120 } 121 122 func verifySubprotocol(subprotos []string, resp *http.Response) error { 123 proto := resp.Header.Get("Sec-WebSocket-Protocol") 124 if proto == "" { 125 return nil 126 } 127 128 for _, sp2 := range subprotos { 129 if strings.EqualFold(sp2, proto) { 130 return nil 131 } 132 } 133 134 return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) 135 } 136 137 // generateNonce generates a nonce consisting of a randomly selected 16-byte 138 // value that has been base64-encoded. 139 func generateNonce() string { 140 key := make([]byte, 16) 141 if _, err := io.ReadFull(rand.Reader, key); err != nil { 142 panic(err) 143 } 144 return base64.StdEncoding.EncodeToString(key) 145 }