github.com/mdaxf/iac@v0.0.0-20240519030858-58a061660378/vendor_skip/nhooyr.io/websocket/accept.go (about) 1 // +build !js 2 3 package websocket 4 5 import ( 6 "bytes" 7 "crypto/sha1" 8 "encoding/base64" 9 "errors" 10 "fmt" 11 "io" 12 "log" 13 "net/http" 14 "net/textproto" 15 "net/url" 16 "path/filepath" 17 "strings" 18 19 "nhooyr.io/websocket/internal/errd" 20 ) 21 22 // AcceptOptions represents Accept's options. 23 type AcceptOptions struct { 24 // Subprotocols lists the WebSocket subprotocols that Accept will negotiate with the client. 25 // The empty subprotocol will always be negotiated as per RFC 6455. If you would like to 26 // reject it, close the connection when c.Subprotocol() == "". 27 Subprotocols []string 28 29 // InsecureSkipVerify is used to disable Accept's origin verification behaviour. 30 // 31 // You probably want to use OriginPatterns instead. 32 InsecureSkipVerify bool 33 34 // OriginPatterns lists the host patterns for authorized origins. 35 // The request host is always authorized. 36 // Use this to enable cross origin WebSockets. 37 // 38 // i.e javascript running on example.com wants to access a WebSocket server at chat.example.com. 39 // In such a case, example.com is the origin and chat.example.com is the request host. 40 // One would set this field to []string{"example.com"} to authorize example.com to connect. 41 // 42 // Each pattern is matched case insensitively against the request origin host 43 // with filepath.Match. 44 // See https://golang.org/pkg/path/filepath/#Match 45 // 46 // Please ensure you understand the ramifications of enabling this. 47 // If used incorrectly your WebSocket server will be open to CSRF attacks. 48 // 49 // Do not use * as a pattern to allow any origin, prefer to use InsecureSkipVerify instead 50 // to bring attention to the danger of such a setting. 51 OriginPatterns []string 52 53 // CompressionMode controls the compression mode. 54 // Defaults to CompressionNoContextTakeover. 55 // 56 // See docs on CompressionMode for details. 57 CompressionMode CompressionMode 58 59 // CompressionThreshold controls the minimum size of a message before compression is applied. 60 // 61 // Defaults to 512 bytes for CompressionNoContextTakeover and 128 bytes 62 // for CompressionContextTakeover. 63 CompressionThreshold int 64 } 65 66 // Accept accepts a WebSocket handshake from a client and upgrades the 67 // the connection to a WebSocket. 68 // 69 // Accept will not allow cross origin requests by default. 70 // See the InsecureSkipVerify and OriginPatterns options to allow cross origin requests. 71 // 72 // Accept will write a response to w on all errors. 73 func Accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (*Conn, error) { 74 return accept(w, r, opts) 75 } 76 77 func accept(w http.ResponseWriter, r *http.Request, opts *AcceptOptions) (_ *Conn, err error) { 78 defer errd.Wrap(&err, "failed to accept WebSocket connection") 79 80 if opts == nil { 81 opts = &AcceptOptions{} 82 } 83 opts = &*opts 84 85 errCode, err := verifyClientRequest(w, r) 86 if err != nil { 87 http.Error(w, err.Error(), errCode) 88 return nil, err 89 } 90 91 if !opts.InsecureSkipVerify { 92 err = authenticateOrigin(r, opts.OriginPatterns) 93 if err != nil { 94 if errors.Is(err, filepath.ErrBadPattern) { 95 log.Printf("websocket: %v", err) 96 err = errors.New(http.StatusText(http.StatusForbidden)) 97 } 98 http.Error(w, err.Error(), http.StatusForbidden) 99 return nil, err 100 } 101 } 102 103 hj, ok := w.(http.Hijacker) 104 if !ok { 105 err = errors.New("http.ResponseWriter does not implement http.Hijacker") 106 http.Error(w, http.StatusText(http.StatusNotImplemented), http.StatusNotImplemented) 107 return nil, err 108 } 109 110 w.Header().Set("Upgrade", "websocket") 111 w.Header().Set("Connection", "Upgrade") 112 113 key := r.Header.Get("Sec-WebSocket-Key") 114 w.Header().Set("Sec-WebSocket-Accept", secWebSocketAccept(key)) 115 116 subproto := selectSubprotocol(r, opts.Subprotocols) 117 if subproto != "" { 118 w.Header().Set("Sec-WebSocket-Protocol", subproto) 119 } 120 121 copts, err := acceptCompression(r, w, opts.CompressionMode) 122 if err != nil { 123 return nil, err 124 } 125 126 w.WriteHeader(http.StatusSwitchingProtocols) 127 // See https://github.com/nhooyr/websocket/issues/166 128 if ginWriter, ok := w.(interface { 129 WriteHeaderNow() 130 }); ok { 131 ginWriter.WriteHeaderNow() 132 } 133 134 netConn, brw, err := hj.Hijack() 135 if err != nil { 136 err = fmt.Errorf("failed to hijack connection: %w", err) 137 http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) 138 return nil, err 139 } 140 141 // https://github.com/golang/go/issues/32314 142 b, _ := brw.Reader.Peek(brw.Reader.Buffered()) 143 brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), netConn)) 144 145 return newConn(connConfig{ 146 subprotocol: w.Header().Get("Sec-WebSocket-Protocol"), 147 rwc: netConn, 148 client: false, 149 copts: copts, 150 flateThreshold: opts.CompressionThreshold, 151 152 br: brw.Reader, 153 bw: brw.Writer, 154 }), nil 155 } 156 157 func verifyClientRequest(w http.ResponseWriter, r *http.Request) (errCode int, _ error) { 158 if !r.ProtoAtLeast(1, 1) { 159 return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: handshake request must be at least HTTP/1.1: %q", r.Proto) 160 } 161 162 if !headerContainsTokenIgnoreCase(r.Header, "Connection", "Upgrade") { 163 w.Header().Set("Connection", "Upgrade") 164 w.Header().Set("Upgrade", "websocket") 165 return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Connection header %q does not contain Upgrade", r.Header.Get("Connection")) 166 } 167 168 if !headerContainsTokenIgnoreCase(r.Header, "Upgrade", "websocket") { 169 w.Header().Set("Connection", "Upgrade") 170 w.Header().Set("Upgrade", "websocket") 171 return http.StatusUpgradeRequired, fmt.Errorf("WebSocket protocol violation: Upgrade header %q does not contain websocket", r.Header.Get("Upgrade")) 172 } 173 174 if r.Method != "GET" { 175 return http.StatusMethodNotAllowed, fmt.Errorf("WebSocket protocol violation: handshake request method is not GET but %q", r.Method) 176 } 177 178 if r.Header.Get("Sec-WebSocket-Version") != "13" { 179 w.Header().Set("Sec-WebSocket-Version", "13") 180 return http.StatusBadRequest, fmt.Errorf("unsupported WebSocket protocol version (only 13 is supported): %q", r.Header.Get("Sec-WebSocket-Version")) 181 } 182 183 if r.Header.Get("Sec-WebSocket-Key") == "" { 184 return http.StatusBadRequest, errors.New("WebSocket protocol violation: missing Sec-WebSocket-Key") 185 } 186 187 return 0, nil 188 } 189 190 func authenticateOrigin(r *http.Request, originHosts []string) error { 191 origin := r.Header.Get("Origin") 192 if origin == "" { 193 return nil 194 } 195 196 u, err := url.Parse(origin) 197 if err != nil { 198 return fmt.Errorf("failed to parse Origin header %q: %w", origin, err) 199 } 200 201 if strings.EqualFold(r.Host, u.Host) { 202 return nil 203 } 204 205 for _, hostPattern := range originHosts { 206 matched, err := match(hostPattern, u.Host) 207 if err != nil { 208 return fmt.Errorf("failed to parse filepath pattern %q: %w", hostPattern, err) 209 } 210 if matched { 211 return nil 212 } 213 } 214 return fmt.Errorf("request Origin %q is not authorized for Host %q", origin, r.Host) 215 } 216 217 func match(pattern, s string) (bool, error) { 218 return filepath.Match(strings.ToLower(pattern), strings.ToLower(s)) 219 } 220 221 func selectSubprotocol(r *http.Request, subprotocols []string) string { 222 cps := headerTokens(r.Header, "Sec-WebSocket-Protocol") 223 for _, sp := range subprotocols { 224 for _, cp := range cps { 225 if strings.EqualFold(sp, cp) { 226 return cp 227 } 228 } 229 } 230 return "" 231 } 232 233 func acceptCompression(r *http.Request, w http.ResponseWriter, mode CompressionMode) (*compressionOptions, error) { 234 if mode == CompressionDisabled { 235 return nil, nil 236 } 237 238 for _, ext := range websocketExtensions(r.Header) { 239 switch ext.name { 240 case "permessage-deflate": 241 return acceptDeflate(w, ext, mode) 242 // Disabled for now, see https://github.com/nhooyr/websocket/issues/218 243 // case "x-webkit-deflate-frame": 244 // return acceptWebkitDeflate(w, ext, mode) 245 } 246 } 247 return nil, nil 248 } 249 250 func acceptDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { 251 copts := mode.opts() 252 253 for _, p := range ext.params { 254 switch p { 255 case "client_no_context_takeover": 256 copts.clientNoContextTakeover = true 257 continue 258 case "server_no_context_takeover": 259 copts.serverNoContextTakeover = true 260 continue 261 } 262 263 if strings.HasPrefix(p, "client_max_window_bits") { 264 // We cannot adjust the read sliding window so cannot make use of this. 265 continue 266 } 267 268 err := fmt.Errorf("unsupported permessage-deflate parameter: %q", p) 269 http.Error(w, err.Error(), http.StatusBadRequest) 270 return nil, err 271 } 272 273 copts.setHeader(w.Header()) 274 275 return copts, nil 276 } 277 278 func acceptWebkitDeflate(w http.ResponseWriter, ext websocketExtension, mode CompressionMode) (*compressionOptions, error) { 279 copts := mode.opts() 280 // The peer must explicitly request it. 281 copts.serverNoContextTakeover = false 282 283 for _, p := range ext.params { 284 if p == "no_context_takeover" { 285 copts.serverNoContextTakeover = true 286 continue 287 } 288 289 // We explicitly fail on x-webkit-deflate-frame's max_window_bits parameter instead 290 // of ignoring it as the draft spec is unclear. It says the server can ignore it 291 // but the server has no way of signalling to the client it was ignored as the parameters 292 // are set one way. 293 // Thus us ignoring it would make the client think we understood it which would cause issues. 294 // See https://tools.ietf.org/html/draft-tyoshino-hybi-websocket-perframe-deflate-06#section-4.1 295 // 296 // Either way, we're only implementing this for webkit which never sends the max_window_bits 297 // parameter so we don't need to worry about it. 298 err := fmt.Errorf("unsupported x-webkit-deflate-frame parameter: %q", p) 299 http.Error(w, err.Error(), http.StatusBadRequest) 300 return nil, err 301 } 302 303 s := "x-webkit-deflate-frame" 304 if copts.clientNoContextTakeover { 305 s += "; no_context_takeover" 306 } 307 w.Header().Set("Sec-WebSocket-Extensions", s) 308 309 return copts, nil 310 } 311 312 func headerContainsTokenIgnoreCase(h http.Header, key, token string) bool { 313 for _, t := range headerTokens(h, key) { 314 if strings.EqualFold(t, token) { 315 return true 316 } 317 } 318 return false 319 } 320 321 type websocketExtension struct { 322 name string 323 params []string 324 } 325 326 func websocketExtensions(h http.Header) []websocketExtension { 327 var exts []websocketExtension 328 extStrs := headerTokens(h, "Sec-WebSocket-Extensions") 329 for _, extStr := range extStrs { 330 if extStr == "" { 331 continue 332 } 333 334 vals := strings.Split(extStr, ";") 335 for i := range vals { 336 vals[i] = strings.TrimSpace(vals[i]) 337 } 338 339 e := websocketExtension{ 340 name: vals[0], 341 params: vals[1:], 342 } 343 344 exts = append(exts, e) 345 } 346 return exts 347 } 348 349 func headerTokens(h http.Header, key string) []string { 350 key = textproto.CanonicalMIMEHeaderKey(key) 351 var tokens []string 352 for _, v := range h[key] { 353 v = strings.TrimSpace(v) 354 for _, t := range strings.Split(v, ",") { 355 t = strings.TrimSpace(t) 356 tokens = append(tokens, t) 357 } 358 } 359 return tokens 360 } 361 362 var keyGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11") 363 364 func secWebSocketAccept(secWebSocketKey string) string { 365 h := sha1.New() 366 h.Write([]byte(secWebSocketKey)) 367 h.Write(keyGUID) 368 369 return base64.StdEncoding.EncodeToString(h.Sum(nil)) 370 }