github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/dial.go (about) 1 // +build !js 2 3 package websocket 4 5 import ( 6 "bufio" 7 "bytes" 8 "context" 9 "crypto/rand" 10 "encoding/base64" 11 "fmt" 12 "io" 13 "io/ioutil" 14 "net/http" 15 "net/url" 16 "strings" 17 "sync" 18 "time" 19 20 "nhooyr.io/websocket/internal/errd" 21 ) 22 23 // DialOptions represents Dial's options. 24 type DialOptions struct { 25 // HTTPClient is used for the connection. 26 // Its Transport must return writable bodies for WebSocket handshakes. 27 // http.Transport does beginning with Go 1.12. 28 HTTPClient *http.Client 29 30 // HTTPHeader specifies the HTTP headers included in the handshake request. 31 HTTPHeader http.Header 32 33 // Subprotocols lists the WebSocket subprotocols to negotiate with the server. 34 Subprotocols []string 35 36 // CompressionMode controls the compression mode. 37 // Defaults to CompressionNoContextTakeover. 38 // 39 // See docs on CompressionMode for details. 40 CompressionMode CompressionMode 41 42 // CompressionThreshold controls the minimum size of a message before compression is applied. 43 // 44 // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes 45 // for CompressionContextTakeover. 46 CompressionThreshold int 47 } 48 49 // Dial performs a WebSocket handshake on url. 50 // 51 // The response is the WebSocket handshake response from the server. 52 // You never need to close resp.Body yourself. 53 // 54 // If an error occurs, the returned response may be non nil. 55 // However, you can only read the first 1024 bytes of the body. 56 // 57 // This function requires at least Go 1.12 as it uses a new feature 58 // in net/http to perform WebSocket handshakes. 59 // See docs on the HTTPClient option and https://github.com/golang/go/issues/26937#issuecomment-415855861 60 // 61 // URLs with http/https schemes will work and are interpreted as ws/wss. 62 func Dial(ctx context.Context, u string, opts *DialOptions) (*Conn, *http.Response, error) { 63 return dial(ctx, u, opts, nil) 64 } 65 66 func dial(ctx context.Context, urls string, opts *DialOptions, rand io.Reader) (_ *Conn, _ *http.Response, err error) { 67 defer errd.Wrap(&err, "failed to WebSocket dial") 68 69 if opts == nil { 70 opts = &DialOptions{} 71 } 72 73 opts = &*opts 74 if opts.HTTPClient == nil { 75 opts.HTTPClient = http.DefaultClient 76 } else if opts.HTTPClient.Timeout > 0 { 77 var cancel context.CancelFunc 78 79 ctx, cancel = context.WithTimeout(ctx, opts.HTTPClient.Timeout) 80 defer cancel() 81 82 newClient := *opts.HTTPClient 83 newClient.Timeout = 0 84 opts.HTTPClient = &newClient 85 } 86 87 if opts.HTTPHeader == nil { 88 opts.HTTPHeader = http.Header{} 89 } 90 91 secWebSocketKey, err := secWebSocketKey(rand) 92 if err != nil { 93 return nil, nil, fmt.Errorf("failed to generate Sec-WebSocket-Key: %w", err) 94 } 95 96 var copts *compressionOptions 97 if opts.CompressionMode != CompressionDisabled { 98 copts = opts.CompressionMode.opts() 99 } 100 101 resp, err := handshakeRequest(ctx, urls, opts, copts, secWebSocketKey) 102 if err != nil { 103 return nil, resp, err 104 } 105 respBody := resp.Body 106 resp.Body = nil 107 defer func() { 108 if err != nil { 109 // We read a bit of the body for easier debugging. 110 r := io.LimitReader(respBody, 1024) 111 112 timer := time.AfterFunc(time.Second*3, func() { 113 respBody.Close() 114 }) 115 defer timer.Stop() 116 117 b, _ := ioutil.ReadAll(r) 118 respBody.Close() 119 resp.Body = ioutil.NopCloser(bytes.NewReader(b)) 120 } 121 }() 122 123 copts, err = verifyServerResponse(opts, copts, secWebSocketKey, resp) 124 if err != nil { 125 return nil, resp, err 126 } 127 128 rwc, ok := respBody.(io.ReadWriteCloser) 129 if !ok { 130 return nil, resp, fmt.Errorf("response body is not a io.ReadWriteCloser: %T", respBody) 131 } 132 133 return newConn(connConfig{ 134 subprotocol: resp.Header.Get("Sec-WebSocket-Protocol"), 135 rwc: rwc, 136 client: true, 137 copts: copts, 138 flateThreshold: opts.CompressionThreshold, 139 br: getBufioReader(rwc), 140 bw: getBufioWriter(rwc), 141 }), resp, nil 142 } 143 144 func handshakeRequest(ctx context.Context, urls string, opts *DialOptions, copts *compressionOptions, secWebSocketKey string) (*http.Response, error) { 145 u, err := url.Parse(urls) 146 if err != nil { 147 return nil, fmt.Errorf("failed to parse url: %w", err) 148 } 149 150 switch u.Scheme { 151 case "ws": 152 u.Scheme = "http" 153 case "wss": 154 u.Scheme = "https" 155 case "http", "https": 156 default: 157 return nil, fmt.Errorf("unexpected url scheme: %q", u.Scheme) 158 } 159 160 req, _ := http.NewRequestWithContext(ctx, "GET", u.String(), nil) 161 req.Header = opts.HTTPHeader.Clone() 162 req.Header.Set("Connection", "Upgrade") 163 req.Header.Set("Upgrade", "websocket") 164 req.Header.Set("Sec-WebSocket-Version", "13") 165 req.Header.Set("Sec-WebSocket-Key", secWebSocketKey) 166 if len(opts.Subprotocols) > 0 { 167 req.Header.Set("Sec-WebSocket-Protocol", strings.Join(opts.Subprotocols, ",")) 168 } 169 if copts != nil { 170 copts.setHeader(req.Header) 171 } 172 173 resp, err := opts.HTTPClient.Do(req) 174 if err != nil { 175 return nil, fmt.Errorf("failed to send handshake request: %w", err) 176 } 177 return resp, nil 178 } 179 180 func secWebSocketKey(rr io.Reader) (string, error) { 181 if rr == nil { 182 rr = rand.Reader 183 } 184 b := make([]byte, 16) 185 _, err := io.ReadFull(rr, b) 186 if err != nil { 187 return "", fmt.Errorf("failed to read random data from rand.Reader: %w", err) 188 } 189 return base64.StdEncoding.EncodeToString(b), nil 190 } 191 192 func verifyServerResponse(opts *DialOptions, copts *compressionOptions, secWebSocketKey string, resp *http.Response) (*compressionOptions, error) { 193 if resp.StatusCode != http.StatusSwitchingProtocols { 194 return nil, fmt.Errorf("expected handshake response status code %v but got %v", http.StatusSwitchingProtocols, resp.StatusCode) 195 } 196 197 if !headerContainsTokenIgnoreCase(resp.Header, "Connection", "Upgrade") { 198 return nil, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", resp.Header.Get("Connection")) 199 } 200 201 if !headerContainsTokenIgnoreCase(resp.Header, "Upgrade", "WebSocket") { 202 return nil, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", resp.Header.Get("Upgrade")) 203 } 204 205 if resp.Header.Get("Sec-WebSocket-Accept") != secWebSocketAccept(secWebSocketKey) { 206 return nil, fmt.Errorf("WebSocket protocol violation: invalid Sec-WebSocket-Accept %q, key %q", 207 resp.Header.Get("Sec-WebSocket-Accept"), 208 secWebSocketKey, 209 ) 210 } 211 212 err := verifySubprotocol(opts.Subprotocols, resp) 213 if err != nil { 214 return nil, err 215 } 216 217 return verifyServerExtensions(copts, resp.Header) 218 } 219 220 func verifySubprotocol(subprotos []string, resp *http.Response) error { 221 proto := resp.Header.Get("Sec-WebSocket-Protocol") 222 if proto == "" { 223 return nil 224 } 225 226 for _, sp2 := range subprotos { 227 if strings.EqualFold(sp2, proto) { 228 return nil 229 } 230 } 231 232 return fmt.Errorf("WebSocket protocol violation: unexpected Sec-WebSocket-Protocol from server: %q", proto) 233 } 234 235 func verifyServerExtensions(copts *compressionOptions, h http.Header) (*compressionOptions, error) { 236 exts := websocketExtensions(h) 237 if len(exts) == 0 { 238 return nil, nil 239 } 240 241 ext := exts[0] 242 if ext.name != "permessage-deflate" || len(exts) > 1 || copts == nil { 243 return nil, fmt.Errorf("WebSocket protcol violation: unsupported extensions from server: %+v", exts[1:]) 244 } 245 246 copts = &*copts 247 248 for _, p := range ext.params { 249 switch p { 250 case "client_no_context_takeover": 251 copts.clientNoContextTakeover = true 252 continue 253 case "server_no_context_takeover": 254 copts.serverNoContextTakeover = true 255 continue 256 } 257 258 return nil, fmt.Errorf("unsupported permessage-deflate parameter: %q", p) 259 } 260 261 return copts, nil 262 } 263 264 var bufioReaderPool sync.Pool 265 266 func getBufioReader(r io.Reader) *bufio.Reader { 267 br, ok := bufioReaderPool.Get().(*bufio.Reader) 268 if !ok { 269 return bufio.NewReader(r) 270 } 271 br.Reset(r) 272 return br 273 } 274 275 func putBufioReader(br *bufio.Reader) { 276 bufioReaderPool.Put(br) 277 } 278 279 var bufioWriterPool sync.Pool 280 281 func getBufioWriter(w io.Writer) *bufio.Writer { 282 bw, ok := bufioWriterPool.Get().(*bufio.Writer) 283 if !ok { 284 return bufio.NewWriter(w) 285 } 286 bw.Reset(w) 287 return bw 288 } 289 290 func putBufioWriter(bw *bufio.Writer) { 291 bufioWriterPool.Put(bw) 292 }