// Copyright 2020-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package server

import (
	"bytes"
	crand "crypto/rand"
	"crypto/sha1"
	"crypto/tls"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	mrand "math/rand"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"github.com/klauspost/compress/flate"
)

// wsOpCode is a websocket frame opcode: the low 4 bits of the first byte
// of a frame header.
type wsOpCode int

const (
	// From https://tools.ietf.org/html/rfc6455#section-5.2
	wsTextMessage   = wsOpCode(1)
	wsBinaryMessage = wsOpCode(2)
	wsCloseMessage  = wsOpCode(8)
	wsPingMessage   = wsOpCode(9)
	wsPongMessage   = wsOpCode(10)

	// Bits of the first header byte.
	wsFinalBit = 1 << 7
	wsRsv1Bit  = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6
	wsRsv2Bit  = 1 << 5
	wsRsv3Bit  = 1 << 4

	// Bit of the second header byte indicating a masked payload.
	wsMaskBit = 1 << 7

	wsContinuationFrame = 0
	// 2 bytes + 8 bytes of extended length + 4 bytes of masking key.
	wsMaxFrameHeaderSize    = 14 // Since LeafNode may need to behave as a client
	wsMaxControlPayloadSize = 125
	wsFrameSizeForBrowsers  = 4096 // From experiment, web browsers behave better with limited frame size
	wsCompressThreshold     = 64   // Don't compress for small buffer(s)
	// Size of the status code that starts a close frame payload (sic: "Satus").
	wsCloseSatusSize = 2

	// From https://tools.ietf.org/html/rfc6455#section-11.7
	wsCloseStatusNormalClosure      = 1000
	wsCloseStatusGoingAway          = 1001
	wsCloseStatusProtocolError      = 1002
	wsCloseStatusUnsupportedData    = 1003
	wsCloseStatusNoStatusReceived   = 1005
	wsCloseStatusAbnormalClosure    = 1006
	wsCloseStatusInvalidPayloadData = 1007
	wsCloseStatusPolicyViolation    = 1008
	wsCloseStatusMessageTooBig      = 1009
	wsCloseStatusInternalSrvError   = 1011
	wsCloseStatusTLSHandshake       = 1015

	// Readable names for the boolean arguments of wsFillFrameHeader.
	wsFirstFrame        = true
	wsContFrame         = false
	wsFinalFrame        = true
	wsUncompressedFrame = false

	wsSchemePrefix    = "ws"
	wsSchemePrefixTLS = "wss"

	wsNoMaskingHeader       = "Nats-No-Masking"
	wsNoMaskingValue        = "true"
	wsXForwardedForHeader   = "X-Forwarded-For"
	wsNoMaskingFullResponse = wsNoMaskingHeader + ": " + wsNoMaskingValue + CR_LF
	wsPMCExtension          = "permessage-deflate" // per-message compression
	wsPMCSrvNoCtx           = "server_no_context_takeover"
	wsPMCCliNoCtx           = "client_no_context_takeover"
	wsPMCReqHeaderValue     = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx
	wsPMCFullResponse       = "Sec-WebSocket-Extensions: " + wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx + _CRLF_
	wsSecProto              = "Sec-Websocket-Protocol"
	wsMQTTSecProtoVal       = "mqtt"
	wsMQTTSecProto          = wsSecProto + ": " + wsMQTTSecProtoVal + CR_LF
)

// decompressorPool holds flate readers that wsReadInfo.decompress reuses
// (through flate.Resetter) instead of allocating a new reader per message.
var decompressorPool sync.Pool

// compressLastBlock is appended to a message's compressed fragments before
// inflating them: the 0x00 0x00 0xff 0xff tail (see
// https://tools.ietf.org/html/rfc7692#section-7.2.2) followed by a final
// empty block, so that the flate reader does not report an unexpected EOF.
var compressLastBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}

// From https://tools.ietf.org/html/rfc6455#section-1.3
// Concatenated with the client's key to compute the Sec-WebSocket-Accept value.
var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// Test can enable this so that server does not support "no-masking" requests.
var wsTestRejectNoMasking = false

// websocket is the per-connection websocket state (client.ws).
type websocket struct {
	frames     net.Buffers   // Outgoing frames waiting to be flushed.
	fs         int64         // Total size, in bytes, of `frames`.
	closeMsg   []byte        // Close frame, held back until no other buffers are pending.
	compress   bool          // permessage-deflate was negotiated during the upgrade.
	closeSent  bool          // A close frame was enqueued (only one must be sent).
	browser    bool          // The User-Agent suggests a web browser.
	nocompfrag bool          // No fragment for compressed frames
	maskread   bool          // Incoming payloads are expected to be masked.
	maskwrite  bool          // Frames written by this side are masked.
	compressor *flate.Writer // presumably the outbound compressor; not used in this part of the file
	// Credentials extracted from the configured cookies during the upgrade.
	cookieJwt      string
	cookieUsername string
	cookiePassword string
	cookieToken    string
	clientIP       string // Valid IP taken from the X-Forwarded-For header, if any.
}

// srvWebsocket is the server-wide websocket listener state.
type srvWebsocket struct {
	mu             sync.RWMutex // Protects sameOrigin and allowedOrigins.
	server         *http.Server
	listener       net.Listener
	listenerErr    error // Error returned when creating the listener, if any.
	tls            bool  // The listener uses TLS ("wss").
	allowedOrigins map[string]*allowedOrigin // host will be the key
	sameOrigin     bool // Origin must match the request's Host.
	connectURLs    []string
	connectURLsMap refCountedUrlSet
	authOverride   bool // indicate if there is auth override in websocket config
}

// allowedOrigin is the scheme and port that an Origin must have to match
// an allowed-origins entry (the host is the map key).
type allowedOrigin struct {
	scheme string
	port   string
}

// wsUpgradeResult is the outcome of a successful websocket upgrade: the
// hijacked connection, its websocket state and the kind of connection
// (CLIENT, LEAF or MQTT) to create from it.
type wsUpgradeResult struct {
	conn net.Conn
	ws   *websocket
	kind int
}

// wsReadInfo is the frame-parsing state kept across calls to wsRead.
type wsReadInfo struct {
	rem   int  // Payload bytes of the current frame not yet consumed.
	fs    bool // Next byte to parse is the start of a frame header.
	ff    bool // Final frame of the current message has been seen.
	fc    bool // Current message is compressed (RSV1 on its first frame).
	mask  bool // Incoming leafnode connections may not have masking.
	mkpos byte // Current position (0-3) within mkey.
	mkey  [4]byte
	cbufs [][]byte // Copies of compressed payload fragments awaiting decompression.
	coff  int      // Read offset within cbufs[0].
}

// init prepares r for a new stream: the next byte starts a frame header
// and no fragmented message is in progress.
func (r *wsReadInfo) init() {
	r.fs, r.ff = true, true
}

// Returns a slice containing `needed` bytes from the given buffer `buf`
// starting at position `pos`, and possibly read from the given reader `r`.
// When bytes are present in `buf`, the `pos` is incremented by the number
// of bytes found up to `needed` and the new position is returned. If not
// enough bytes are found, the bytes found in `buf` are copied to the returned
// slice and the remaining bytes are read from `r`.
172 func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) { 173 avail := len(buf) - pos 174 if avail >= needed { 175 return buf[pos : pos+needed], pos + needed, nil 176 } 177 b := make([]byte, needed) 178 start := copy(b, buf[pos:]) 179 for start != needed { 180 n, err := r.Read(b[start:cap(b)]) 181 if err != nil { 182 return nil, 0, err 183 } 184 start += n 185 } 186 return b, pos + avail, nil 187 } 188 189 // Returns true if this connection is from a Websocket client. 190 // Lock held on entry. 191 func (c *client) isWebsocket() bool { 192 return c.ws != nil 193 } 194 195 // Returns a slice of byte slices corresponding to payload of websocket frames. 196 // The byte slice `buf` is filled with bytes from the connection's read loop. 197 // This function will decode the frame headers and unmask the payload(s). 198 // It is possible that the returned slices point to the given `buf` slice, so 199 // `buf` should not be overwritten until the returned slices have been parsed. 200 // 201 // Client lock MUST NOT be held on entry. 202 func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, error) { 203 var ( 204 bufs [][]byte 205 tmpBuf []byte 206 err error 207 pos int 208 max = len(buf) 209 ) 210 for pos != max { 211 if r.fs { 212 b0 := buf[pos] 213 frameType := wsOpCode(b0 & 0xF) 214 final := b0&wsFinalBit != 0 215 compressed := b0&wsRsv1Bit != 0 216 pos++ 217 218 tmpBuf, pos, err = wsGet(ior, buf, pos, 1) 219 if err != nil { 220 return bufs, err 221 } 222 b1 := tmpBuf[0] 223 224 // Clients MUST set the mask bit. If not set, reject. 225 // However, LEAF by default will not have masking, unless they are forced to, by configuration. 
226 if r.mask && b1&wsMaskBit == 0 { 227 return bufs, c.wsHandleProtocolError("mask bit missing") 228 } 229 230 // Store size in case it is < 125 231 r.rem = int(b1 & 0x7F) 232 233 switch frameType { 234 case wsPingMessage, wsPongMessage, wsCloseMessage: 235 if r.rem > wsMaxControlPayloadSize { 236 return bufs, c.wsHandleProtocolError( 237 fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes", 238 wsMaxControlPayloadSize)) 239 } 240 if !final { 241 return bufs, c.wsHandleProtocolError("control frame does not have final bit set") 242 } 243 case wsTextMessage, wsBinaryMessage: 244 if !r.ff { 245 return bufs, c.wsHandleProtocolError("new message started before final frame for previous message was received") 246 } 247 r.ff = final 248 r.fc = compressed 249 case wsContinuationFrame: 250 // Compressed bit must be only set in the first frame 251 if r.ff || compressed { 252 return bufs, c.wsHandleProtocolError("invalid continuation frame") 253 } 254 r.ff = final 255 default: 256 return bufs, c.wsHandleProtocolError(fmt.Sprintf("unknown opcode %v", frameType)) 257 } 258 259 switch r.rem { 260 case 126: 261 tmpBuf, pos, err = wsGet(ior, buf, pos, 2) 262 if err != nil { 263 return bufs, err 264 } 265 r.rem = int(binary.BigEndian.Uint16(tmpBuf)) 266 case 127: 267 tmpBuf, pos, err = wsGet(ior, buf, pos, 8) 268 if err != nil { 269 return bufs, err 270 } 271 r.rem = int(binary.BigEndian.Uint64(tmpBuf)) 272 } 273 274 if r.mask { 275 // Read masking key 276 tmpBuf, pos, err = wsGet(ior, buf, pos, 4) 277 if err != nil { 278 return bufs, err 279 } 280 copy(r.mkey[:], tmpBuf) 281 r.mkpos = 0 282 } 283 284 // Handle control messages in place... 
285 if wsIsControlFrame(frameType) { 286 pos, err = c.wsHandleControlFrame(r, frameType, ior, buf, pos) 287 if err != nil { 288 return bufs, err 289 } 290 continue 291 } 292 293 // Done with the frame header 294 r.fs = false 295 } 296 if pos < max { 297 var b []byte 298 var n int 299 300 n = r.rem 301 if pos+n > max { 302 n = max - pos 303 } 304 b = buf[pos : pos+n] 305 pos += n 306 r.rem -= n 307 // If needed, unmask the buffer 308 if r.mask { 309 r.unmask(b) 310 } 311 addToBufs := true 312 // Handle compressed message 313 if r.fc { 314 // Assume that we may have continuation frames or not the full payload. 315 addToBufs = false 316 // Make a copy of the buffer before adding it to the list 317 // of compressed fragments. 318 r.cbufs = append(r.cbufs, append([]byte(nil), b...)) 319 // When we have the final frame and we have read the full payload, 320 // we can decompress it. 321 if r.ff && r.rem == 0 { 322 b, err = r.decompress() 323 if err != nil { 324 return bufs, err 325 } 326 r.fc = false 327 // Now we can add to `bufs` 328 addToBufs = true 329 } 330 } 331 // For non compressed frames, or when we have decompressed the 332 // whole message. 333 if addToBufs { 334 bufs = append(bufs, b) 335 } 336 // If payload has been fully read, then indicate that next 337 // is the start of a frame. 
338 if r.rem == 0 { 339 r.fs = true 340 } 341 } 342 } 343 return bufs, nil 344 } 345 346 func (r *wsReadInfo) Read(dst []byte) (int, error) { 347 if len(dst) == 0 { 348 return 0, nil 349 } 350 if len(r.cbufs) == 0 { 351 return 0, io.EOF 352 } 353 copied := 0 354 rem := len(dst) 355 for buf := r.cbufs[0]; buf != nil && rem > 0; { 356 n := len(buf[r.coff:]) 357 if n > rem { 358 n = rem 359 } 360 copy(dst[copied:], buf[r.coff:r.coff+n]) 361 copied += n 362 rem -= n 363 r.coff += n 364 buf = r.nextCBuf() 365 } 366 return copied, nil 367 } 368 369 func (r *wsReadInfo) nextCBuf() []byte { 370 // We still have remaining data in the first buffer 371 if r.coff != len(r.cbufs[0]) { 372 return r.cbufs[0] 373 } 374 // We read the full first buffer. Reset offset. 375 r.coff = 0 376 // We were at the last buffer, so we are done. 377 if len(r.cbufs) == 1 { 378 r.cbufs = nil 379 return nil 380 } 381 // Here we move to the next buffer. 382 r.cbufs = r.cbufs[1:] 383 return r.cbufs[0] 384 } 385 386 func (r *wsReadInfo) ReadByte() (byte, error) { 387 if len(r.cbufs) == 0 { 388 return 0, io.EOF 389 } 390 b := r.cbufs[0][r.coff] 391 r.coff++ 392 r.nextCBuf() 393 return b, nil 394 } 395 396 func (r *wsReadInfo) decompress() ([]byte, error) { 397 r.coff = 0 398 // As per https://tools.ietf.org/html/rfc7692#section-7.2.2 399 // add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader 400 // does not report unexpected EOF. 401 r.cbufs = append(r.cbufs, compressLastBlock) 402 // Get a decompressor from the pool and bind it to this object (wsReadInfo) 403 // that provides Read() and ReadByte() APIs that will consume the compressed 404 // buffers (r.cbufs). 405 d, _ := decompressorPool.Get().(io.ReadCloser) 406 if d == nil { 407 d = flate.NewReader(r) 408 } else { 409 d.(flate.Resetter).Reset(r, nil) 410 } 411 // This will do the decompression. 412 b, err := io.ReadAll(d) 413 decompressorPool.Put(d) 414 // Now reset the compressed buffers list. 
415 r.cbufs = nil 416 return b, err 417 } 418 419 // Handles the PING, PONG and CLOSE websocket control frames. 420 // 421 // Client lock MUST NOT be held on entry. 422 func (c *client) wsHandleControlFrame(r *wsReadInfo, frameType wsOpCode, nc io.Reader, buf []byte, pos int) (int, error) { 423 var payload []byte 424 var err error 425 426 if r.rem > 0 { 427 payload, pos, err = wsGet(nc, buf, pos, r.rem) 428 if err != nil { 429 return pos, err 430 } 431 if r.mask { 432 r.unmask(payload) 433 } 434 r.rem = 0 435 } 436 switch frameType { 437 case wsCloseMessage: 438 status := wsCloseStatusNoStatusReceived 439 var body string 440 lp := len(payload) 441 // If there is a payload, the status is represented as a 2-byte 442 // unsigned integer (in network byte order). Then, there may be an 443 // optional body. 444 hasStatus, hasBody := lp >= wsCloseSatusSize, lp > wsCloseSatusSize 445 if hasStatus { 446 // Decode the status 447 status = int(binary.BigEndian.Uint16(payload[:wsCloseSatusSize])) 448 // Now if there is a body, capture it and make sure this is a valid UTF-8. 449 if hasBody { 450 body = string(payload[wsCloseSatusSize:]) 451 if !utf8.ValidString(body) { 452 // https://tools.ietf.org/html/rfc6455#section-5.5.1 453 // If body is present, it must be a valid utf8 454 status = wsCloseStatusInvalidPayloadData 455 body = "invalid utf8 body in close frame" 456 } 457 } 458 } 459 clm := wsCreateCloseMessage(status, body) 460 c.wsEnqueueControlMessage(wsCloseMessage, clm) 461 nbPoolPut(clm) // wsEnqueueControlMessage has taken a copy. 462 // Return io.EOF so that readLoop will close the connection as ClientClosed 463 // after processing pending buffers. 464 return pos, io.EOF 465 case wsPingMessage: 466 c.wsEnqueueControlMessage(wsPongMessage, payload) 467 case wsPongMessage: 468 // Nothing to do.. 469 } 470 return pos, nil 471 } 472 473 // Unmask the given slice. 
474 func (r *wsReadInfo) unmask(buf []byte) { 475 p := int(r.mkpos) 476 if len(buf) < 16 { 477 for i := 0; i < len(buf); i++ { 478 buf[i] ^= r.mkey[p&3] 479 p++ 480 } 481 r.mkpos = byte(p & 3) 482 return 483 } 484 var k [8]byte 485 for i := 0; i < 8; i++ { 486 k[i] = r.mkey[(p+i)&3] 487 } 488 km := binary.BigEndian.Uint64(k[:]) 489 n := (len(buf) / 8) * 8 490 for i := 0; i < n; i += 8 { 491 tmp := binary.BigEndian.Uint64(buf[i : i+8]) 492 tmp ^= km 493 binary.BigEndian.PutUint64(buf[i:], tmp) 494 } 495 buf = buf[n:] 496 for i := 0; i < len(buf); i++ { 497 buf[i] ^= r.mkey[p&3] 498 p++ 499 } 500 r.mkpos = byte(p & 3) 501 } 502 503 // Returns true if the op code corresponds to a control frame. 504 func wsIsControlFrame(frameType wsOpCode) bool { 505 return frameType >= wsCloseMessage 506 } 507 508 // Create the frame header. 509 // Encodes the frame type and optional compression flag, and the size of the payload. 510 func wsCreateFrameHeader(useMasking, compressed bool, frameType wsOpCode, l int) ([]byte, []byte) { 511 fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize] 512 n, key := wsFillFrameHeader(fh, useMasking, wsFirstFrame, wsFinalFrame, compressed, frameType, l) 513 return fh[:n], key 514 } 515 516 func wsFillFrameHeader(fh []byte, useMasking, first, final, compressed bool, frameType wsOpCode, l int) (int, []byte) { 517 var n int 518 var b byte 519 if first { 520 b = byte(frameType) 521 } 522 if final { 523 b |= wsFinalBit 524 } 525 if compressed { 526 b |= wsRsv1Bit 527 } 528 b1 := byte(0) 529 if useMasking { 530 b1 |= wsMaskBit 531 } 532 switch { 533 case l <= 125: 534 n = 2 535 fh[0] = b 536 fh[1] = b1 | byte(l) 537 case l < 65536: 538 n = 4 539 fh[0] = b 540 fh[1] = b1 | 126 541 binary.BigEndian.PutUint16(fh[2:], uint16(l)) 542 default: 543 n = 10 544 fh[0] = b 545 fh[1] = b1 | 127 546 binary.BigEndian.PutUint64(fh[2:], uint64(l)) 547 } 548 var key []byte 549 if useMasking { 550 var keyBuf [4]byte 551 if _, err := io.ReadFull(crand.Reader, 
keyBuf[:4]); err != nil { 552 kv := mrand.Int31() 553 binary.LittleEndian.PutUint32(keyBuf[:4], uint32(kv)) 554 } 555 copy(fh[n:], keyBuf[:4]) 556 key = fh[n : n+4] 557 n += 4 558 } 559 return n, key 560 } 561 562 // Invokes wsEnqueueControlMessageLocked under client lock. 563 // 564 // Client lock MUST NOT be held on entry 565 func (c *client) wsEnqueueControlMessage(controlMsg wsOpCode, payload []byte) { 566 c.mu.Lock() 567 c.wsEnqueueControlMessageLocked(controlMsg, payload) 568 c.mu.Unlock() 569 } 570 571 // Mask the buffer with the given key 572 func wsMaskBuf(key, buf []byte) { 573 for i := 0; i < len(buf); i++ { 574 buf[i] ^= key[i&3] 575 } 576 } 577 578 // Mask the buffers, as if they were contiguous, with the given key 579 func wsMaskBufs(key []byte, bufs [][]byte) { 580 pos := 0 581 for i := 0; i < len(bufs); i++ { 582 buf := bufs[i] 583 for j := 0; j < len(buf); j++ { 584 buf[j] ^= key[pos&3] 585 pos++ 586 } 587 } 588 } 589 590 // Enqueues a websocket control message. 591 // If the control message is a wsCloseMessage, then marks this client 592 // has having sent the close message (since only one should be sent). 593 // This will prevent the generic closeConnection() to enqueue one. 594 // 595 // Client lock held on entry. 596 func (c *client) wsEnqueueControlMessageLocked(controlMsg wsOpCode, payload []byte) { 597 // Control messages are never compressed and their size will be 598 // less than wsMaxControlPayloadSize, which means the frame header 599 // will be only 2 or 6 bytes. 600 useMasking := c.ws.maskwrite 601 sz := 2 602 if useMasking { 603 sz += 4 604 } 605 cm := nbPoolGet(sz + len(payload)) 606 cm = cm[:cap(cm)] 607 n, key := wsFillFrameHeader(cm, useMasking, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, controlMsg, len(payload)) 608 cm = cm[:n] 609 // Note that payload is optional. 610 if len(payload) > 0 { 611 cm = append(cm, payload...) 
612 if useMasking { 613 wsMaskBuf(key, cm[n:]) 614 } 615 } 616 c.out.pb += int64(len(cm)) 617 if controlMsg == wsCloseMessage { 618 // We can't add the close message to the frames buffers 619 // now. It will be done on a flushOutbound() when there 620 // are no more pending buffers to send. 621 c.ws.closeSent = true 622 c.ws.closeMsg = cm 623 } else { 624 c.ws.frames = append(c.ws.frames, cm) 625 c.ws.fs += int64(len(cm)) 626 } 627 c.flushSignal() 628 } 629 630 // Enqueues a websocket close message with a status mapped from the given `reason`. 631 // 632 // Client lock held on entry 633 func (c *client) wsEnqueueCloseMessage(reason ClosedState) { 634 var status int 635 switch reason { 636 case ClientClosed: 637 status = wsCloseStatusNormalClosure 638 case AuthenticationTimeout, AuthenticationViolation, SlowConsumerPendingBytes, SlowConsumerWriteDeadline, 639 MaxAccountConnectionsExceeded, MaxConnectionsExceeded, MaxControlLineExceeded, MaxSubscriptionsExceeded, 640 MissingAccount, AuthenticationExpired, Revocation: 641 status = wsCloseStatusPolicyViolation 642 case TLSHandshakeError: 643 status = wsCloseStatusTLSHandshake 644 case ParseError, ProtocolViolation, BadClientProtocolVersion: 645 status = wsCloseStatusProtocolError 646 case MaxPayloadExceeded: 647 status = wsCloseStatusMessageTooBig 648 case ServerShutdown: 649 status = wsCloseStatusGoingAway 650 case WriteError, ReadError, StaleConnection: 651 status = wsCloseStatusAbnormalClosure 652 default: 653 status = wsCloseStatusInternalSrvError 654 } 655 body := wsCreateCloseMessage(status, reason.String()) 656 c.wsEnqueueControlMessageLocked(wsCloseMessage, body) 657 nbPoolPut(body) // wsEnqueueControlMessageLocked has taken a copy. 658 } 659 660 // Create and then enqueue a close message with a protocol error and the 661 // given message. This is invoked when parsing websocket frames. 662 // 663 // Lock MUST NOT be held on entry. 
664 func (c *client) wsHandleProtocolError(message string) error { 665 buf := wsCreateCloseMessage(wsCloseStatusProtocolError, message) 666 c.wsEnqueueControlMessage(wsCloseMessage, buf) 667 nbPoolPut(buf) // wsEnqueueControlMessage has taken a copy. 668 return fmt.Errorf(message) 669 } 670 671 // Create a close message with the given `status` and `body`. 672 // If the `body` is more than the maximum allows control frame payload size, 673 // it is truncated and "..." is added at the end (as a hint that message 674 // is not complete). 675 func wsCreateCloseMessage(status int, body string) []byte { 676 // Since a control message payload is limited in size, we 677 // will limit the text and add trailing "..." if truncated. 678 // The body of a Close Message must be preceded with 2 bytes, 679 // so take that into account for limiting the body length. 680 if len(body) > wsMaxControlPayloadSize-2 { 681 body = body[:wsMaxControlPayloadSize-5] 682 body += "..." 683 } 684 buf := nbPoolGet(2 + len(body))[:2+len(body)] 685 // We need to have a 2 byte unsigned int that represents the error status code 686 // https://tools.ietf.org/html/rfc6455#section-5.5.1 687 binary.BigEndian.PutUint16(buf[:2], uint16(status)) 688 copy(buf[2:], []byte(body)) 689 return buf 690 } 691 692 // Process websocket client handshake. On success, returns the raw net.Conn that 693 // will be used to create a *client object. 694 // Invoked from the HTTP server listening on websocket port. 695 func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeResult, error) { 696 kind := CLIENT 697 if r.URL != nil { 698 ep := r.URL.EscapedPath() 699 if strings.HasSuffix(ep, leafNodeWSPath) { 700 kind = LEAF 701 } else if strings.HasSuffix(ep, mqttWSPath) { 702 kind = MQTT 703 } 704 } 705 706 opts := s.getOpts() 707 708 // From https://tools.ietf.org/html/rfc6455#section-4.2.1 709 // Point 1. 
710 if r.Method != "GET" { 711 return nil, wsReturnHTTPError(w, r, http.StatusMethodNotAllowed, "request method must be GET") 712 } 713 // Point 2. 714 if r.Host == _EMPTY_ { 715 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "'Host' missing in request") 716 } 717 // Point 3. 718 if !wsHeaderContains(r.Header, "Upgrade", "websocket") { 719 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Upgrade'") 720 } 721 // Point 4. 722 if !wsHeaderContains(r.Header, "Connection", "Upgrade") { 723 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Connection'") 724 } 725 // Point 5. 726 key := r.Header.Get("Sec-Websocket-Key") 727 if key == _EMPTY_ { 728 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "key missing") 729 } 730 // Point 6. 731 if !wsHeaderContains(r.Header, "Sec-Websocket-Version", "13") { 732 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid version") 733 } 734 // Others are optional 735 // Point 7. 736 if err := s.websocket.checkOrigin(r); err != nil { 737 return nil, wsReturnHTTPError(w, r, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err)) 738 } 739 // Point 8. 740 // We don't have protocols, so ignore. 741 // Point 9. 742 // Extensions, only support for compression at the moment 743 compress := opts.Websocket.Compression 744 if compress { 745 // Simply check if permessage-deflate extension is present. 
746 compress, _ = wsPMCExtensionSupport(r.Header, true) 747 } 748 // We will do masking if asked (unless we reject for tests) 749 noMasking := r.Header.Get(wsNoMaskingHeader) == wsNoMaskingValue && !wsTestRejectNoMasking 750 751 h := w.(http.Hijacker) 752 conn, brw, err := h.Hijack() 753 if err != nil { 754 if conn != nil { 755 conn.Close() 756 } 757 return nil, wsReturnHTTPError(w, r, http.StatusInternalServerError, err.Error()) 758 } 759 if brw.Reader.Buffered() > 0 { 760 conn.Close() 761 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "client sent data before handshake is complete") 762 } 763 764 var buf [1024]byte 765 p := buf[:0] 766 767 // From https://tools.ietf.org/html/rfc6455#section-4.2.2 768 p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) 769 p = append(p, wsAcceptKey(key)...) 770 p = append(p, _CRLF_...) 771 if compress { 772 p = append(p, wsPMCFullResponse...) 773 } 774 if noMasking { 775 p = append(p, wsNoMaskingFullResponse...) 776 } 777 if kind == MQTT { 778 p = append(p, wsMQTTSecProto...) 779 } 780 p = append(p, _CRLF_...) 781 782 if _, err = conn.Write(p); err != nil { 783 conn.Close() 784 return nil, err 785 } 786 // If there was a deadline set for the handshake, clear it now. 787 if opts.Websocket.HandshakeTimeout > 0 { 788 conn.SetDeadline(time.Time{}) 789 } 790 // Server always expect "clients" to send masked payload, unless the option 791 // "no-masking" has been enabled. 792 ws := &websocket{compress: compress, maskread: !noMasking} 793 794 // Check for X-Forwarded-For header 795 if cips, ok := r.Header[wsXForwardedForHeader]; ok { 796 cip := cips[0] 797 if net.ParseIP(cip) != nil { 798 ws.clientIP = cip 799 } 800 } 801 802 if kind == CLIENT || kind == MQTT { 803 // Indicate if this is likely coming from a browser. 
804 if ua := r.Header.Get("User-Agent"); ua != _EMPTY_ && strings.HasPrefix(ua, "Mozilla/") { 805 ws.browser = true 806 // Disable fragmentation of compressed frames for Safari browsers. 807 // Unfortunately, you could be running Chrome on macOS and this 808 // string will contain "Safari/" (along "Chrome/"). However, what 809 // I have found is that actual Safari browser also have "Version/". 810 // So make the combination of the two. 811 ws.nocompfrag = ws.compress && strings.Contains(ua, "Version/") && strings.Contains(ua, "Safari/") 812 } 813 814 if cookies := r.Cookies(); len(cookies) > 0 { 815 ows := &opts.Websocket 816 for _, c := range cookies { 817 if ows.JWTCookie == c.Name { 818 ws.cookieJwt = c.Value 819 } else if ows.UsernameCookie == c.Name { 820 ws.cookieUsername = c.Value 821 } else if ows.PasswordCookie == c.Name { 822 ws.cookiePassword = c.Value 823 } else if ows.TokenCookie == c.Name { 824 ws.cookieToken = c.Value 825 } 826 } 827 } 828 } 829 return &wsUpgradeResult{conn: conn, ws: ws, kind: kind}, nil 830 } 831 832 // Returns true if the header named `name` contains a token with value `value`. 
833 func wsHeaderContains(header http.Header, name string, value string) bool { 834 for _, s := range header[name] { 835 tokens := strings.Split(s, ",") 836 for _, t := range tokens { 837 t = strings.Trim(t, " \t") 838 if strings.EqualFold(t, value) { 839 return true 840 } 841 } 842 } 843 return false 844 } 845 846 func wsPMCExtensionSupport(header http.Header, checkPMCOnly bool) (bool, bool) { 847 for _, extensionList := range header["Sec-Websocket-Extensions"] { 848 extensions := strings.Split(extensionList, ",") 849 for _, extension := range extensions { 850 extension = strings.Trim(extension, " \t") 851 params := strings.Split(extension, ";") 852 for i, p := range params { 853 p = strings.Trim(p, " \t") 854 if strings.EqualFold(p, wsPMCExtension) { 855 if checkPMCOnly { 856 return true, false 857 } 858 var snc bool 859 var cnc bool 860 for j := i + 1; j < len(params); j++ { 861 p = params[j] 862 p = strings.Trim(p, " \t") 863 if strings.EqualFold(p, wsPMCSrvNoCtx) { 864 snc = true 865 } else if strings.EqualFold(p, wsPMCCliNoCtx) { 866 cnc = true 867 } 868 if snc && cnc { 869 return true, true 870 } 871 } 872 return true, false 873 } 874 } 875 } 876 } 877 return false, false 878 } 879 880 // Send an HTTP error with the given `status` to the given http response writer `w`. 881 // Return an error created based on the `reason` string. 882 func wsReturnHTTPError(w http.ResponseWriter, r *http.Request, status int, reason string) error { 883 err := fmt.Errorf("%s - websocket handshake error: %s", r.RemoteAddr, reason) 884 w.Header().Set("Sec-Websocket-Version", "13") 885 http.Error(w, http.StatusText(status), status) 886 return err 887 } 888 889 // If the server is configured to accept any origin, then this function returns 890 // `nil` without checking if the Origin is present and valid. This is also 891 // the case if the request does not have the Origin header. 
892 // Otherwise, this will check that the Origin matches the same origin or 893 // any origin in the allowed list. 894 func (w *srvWebsocket) checkOrigin(r *http.Request) error { 895 w.mu.RLock() 896 checkSame := w.sameOrigin 897 listEmpty := len(w.allowedOrigins) == 0 898 w.mu.RUnlock() 899 if !checkSame && listEmpty { 900 return nil 901 } 902 origin := r.Header.Get("Origin") 903 if origin == _EMPTY_ { 904 origin = r.Header.Get("Sec-Websocket-Origin") 905 } 906 // If the header is not present, we will accept. 907 // From https://datatracker.ietf.org/doc/html/rfc6455#section-1.6 908 // "Naturally, when the WebSocket Protocol is used by a dedicated client 909 // directly (i.e., not from a web page through a web browser), the origin 910 // model is not useful, as the client can provide any arbitrary origin string." 911 if origin == _EMPTY_ { 912 return nil 913 } 914 u, err := url.ParseRequestURI(origin) 915 if err != nil { 916 return err 917 } 918 oh, op, err := wsGetHostAndPort(u.Scheme == "https", u.Host) 919 if err != nil { 920 return err 921 } 922 // If checking same origin, compare with the http's request's Host. 923 if checkSame { 924 rh, rp, err := wsGetHostAndPort(r.TLS != nil, r.Host) 925 if err != nil { 926 return err 927 } 928 if oh != rh || op != rp { 929 return errors.New("not same origin") 930 } 931 // I guess it is possible to have cases where one wants to check 932 // same origin, but also that the origin is in the allowed list. 933 // So continue with the next check. 
934 } 935 if !listEmpty { 936 w.mu.RLock() 937 ao := w.allowedOrigins[oh] 938 w.mu.RUnlock() 939 if ao == nil || u.Scheme != ao.scheme || op != ao.port { 940 return errors.New("not in the allowed list") 941 } 942 } 943 return nil 944 } 945 946 func wsGetHostAndPort(tls bool, hostport string) (string, string, error) { 947 host, port, err := net.SplitHostPort(hostport) 948 if err != nil { 949 // If error is missing port, then use defaults based on the scheme 950 if ae, ok := err.(*net.AddrError); ok && strings.Contains(ae.Err, "missing port") { 951 err = nil 952 host = hostport 953 if tls { 954 port = "443" 955 } else { 956 port = "80" 957 } 958 } 959 } 960 return strings.ToLower(host), port, err 961 } 962 963 // Concatenate the key sent by the client with the GUID, then computes the SHA1 hash 964 // and returns it as a based64 encoded string. 965 func wsAcceptKey(key string) string { 966 h := sha1.New() 967 h.Write([]byte(key)) 968 h.Write(wsGUID) 969 return base64.StdEncoding.EncodeToString(h.Sum(nil)) 970 } 971 972 func wsMakeChallengeKey() (string, error) { 973 p := make([]byte, 16) 974 if _, err := io.ReadFull(crand.Reader, p); err != nil { 975 return _EMPTY_, err 976 } 977 return base64.StdEncoding.EncodeToString(p), nil 978 } 979 980 // Validate the websocket related options. 981 func validateWebsocketOptions(o *Options) error { 982 wo := &o.Websocket 983 // If no port is defined, we don't care about other options 984 if wo.Port == 0 { 985 return nil 986 } 987 // Enforce TLS... unless NoTLS is set to true. 988 if wo.TLSConfig == nil && !wo.NoTLS { 989 return errors.New("websocket requires TLS configuration") 990 } 991 // Make sure that allowed origins, if specified, can be parsed. 992 for _, ao := range wo.AllowedOrigins { 993 if _, err := url.Parse(ao); err != nil { 994 return fmt.Errorf("unable to parse allowed origin: %v", err) 995 } 996 } 997 // If there is a NoAuthUser, we need to have Users defined and 998 // the user to be present. 
999 if wo.NoAuthUser != _EMPTY_ { 1000 if err := validateNoAuthUser(o, wo.NoAuthUser); err != nil { 1001 return err 1002 } 1003 } 1004 // Token/Username not possible if there are users/nkeys 1005 if len(o.Users) > 0 || len(o.Nkeys) > 0 { 1006 if wo.Username != _EMPTY_ { 1007 return fmt.Errorf("websocket authentication username not compatible with presence of users/nkeys") 1008 } 1009 if wo.Token != _EMPTY_ { 1010 return fmt.Errorf("websocket authentication token not compatible with presence of users/nkeys") 1011 } 1012 } 1013 // Using JWT requires Trusted Keys 1014 if wo.JWTCookie != _EMPTY_ { 1015 if len(o.TrustedOperators) == 0 && len(o.TrustedKeys) == 0 { 1016 return fmt.Errorf("trusted operators or trusted keys configuration is required for JWT authentication via cookie %q", wo.JWTCookie) 1017 } 1018 } 1019 if err := validatePinnedCerts(wo.TLSPinnedCerts); err != nil { 1020 return fmt.Errorf("websocket: %v", err) 1021 } 1022 return nil 1023 } 1024 1025 // Creates or updates the existing map 1026 func (s *Server) wsSetOriginOptions(o *WebsocketOpts) { 1027 ws := &s.websocket 1028 ws.mu.Lock() 1029 defer ws.mu.Unlock() 1030 // Copy over the option's same origin boolean 1031 ws.sameOrigin = o.SameOrigin 1032 // Reset the map. Will help for config reload if/when we support it. 1033 ws.allowedOrigins = nil 1034 if o.AllowedOrigins == nil { 1035 return 1036 } 1037 for _, ao := range o.AllowedOrigins { 1038 // We have previously checked (during options validation) that the urls 1039 // are parseable, but if we get an error, report and skip. 
1040 u, err := url.ParseRequestURI(ao) 1041 if err != nil { 1042 s.Errorf("error parsing allowed origin: %v", err) 1043 continue 1044 } 1045 h, p, _ := wsGetHostAndPort(u.Scheme == "https", u.Host) 1046 if ws.allowedOrigins == nil { 1047 ws.allowedOrigins = make(map[string]*allowedOrigin, len(o.AllowedOrigins)) 1048 } 1049 ws.allowedOrigins[h] = &allowedOrigin{scheme: u.Scheme, port: p} 1050 } 1051 } 1052 1053 // Given the websocket options, we check if any auth configuration 1054 // has been provided. If so, possibly create users/nkey users and 1055 // store them in s.websocket.users/nkeys. 1056 // Also update a boolean that indicates if auth is required for 1057 // websocket clients. 1058 // Server lock is held on entry. 1059 func (s *Server) wsConfigAuth(opts *WebsocketOpts) { 1060 ws := &s.websocket 1061 // If any of those is specified, we consider that there is an override. 1062 ws.authOverride = opts.Username != _EMPTY_ || opts.Token != _EMPTY_ || opts.NoAuthUser != _EMPTY_ 1063 } 1064 1065 func (s *Server) startWebsocketServer() { 1066 if s.isShuttingDown() { 1067 return 1068 } 1069 1070 sopts := s.getOpts() 1071 o := &sopts.Websocket 1072 1073 s.wsSetOriginOptions(o) 1074 1075 var hl net.Listener 1076 var proto string 1077 var err error 1078 1079 port := o.Port 1080 if port == -1 { 1081 port = 0 1082 } 1083 hp := net.JoinHostPort(o.Host, strconv.Itoa(port)) 1084 1085 // We are enforcing (when validating the options) the use of TLS, but the 1086 // code was originally supporting both modes. The reason for TLS only is 1087 // that we expect users to send JWTs with bearer tokens and we want to 1088 // avoid the possibility of it being "intercepted". 1089 1090 s.mu.Lock() 1091 // Do not check o.NoTLS here. If a TLS configuration is available, use it, 1092 // regardless of NoTLS. If we don't have a TLS config, it means that the 1093 // user has configured NoTLS because otherwise the server would have failed 1094 // to start due to options validation. 
1095 if o.TLSConfig != nil { 1096 proto = wsSchemePrefixTLS 1097 config := o.TLSConfig.Clone() 1098 config.GetConfigForClient = s.wsGetTLSConfig 1099 hl, err = tls.Listen("tcp", hp, config) 1100 } else { 1101 proto = wsSchemePrefix 1102 hl, err = s.network.ListenCause("tcp", hp, "ws") 1103 } 1104 s.websocket.listenerErr = err 1105 if err != nil { 1106 s.mu.Unlock() 1107 s.Fatalf("Unable to listen for websocket connections: %v", err) 1108 return 1109 } 1110 if port == 0 { 1111 o.Port = hl.Addr().(*net.TCPAddr).Port 1112 } 1113 s.Noticef("Listening for websocket clients on %s://%s:%d", proto, o.Host, o.Port) 1114 if proto == wsSchemePrefix { 1115 s.Warnf("Websocket not configured with TLS. DO NOT USE IN PRODUCTION!") 1116 } 1117 1118 s.websocket.tls = proto == "wss" 1119 s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port) 1120 if err != nil { 1121 s.Fatalf("Unable to get websocket connect URLs: %v", err) 1122 hl.Close() 1123 s.mu.Unlock() 1124 return 1125 } 1126 hasLeaf := sopts.LeafNode.Port != 0 1127 mux := http.NewServeMux() 1128 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1129 res, err := s.wsUpgrade(w, r) 1130 if err != nil { 1131 s.Errorf(err.Error()) 1132 return 1133 } 1134 switch res.kind { 1135 case CLIENT: 1136 s.createWSClient(res.conn, res.ws) 1137 case MQTT: 1138 s.createMQTTClient(res.conn, res.ws) 1139 case LEAF: 1140 if !hasLeaf { 1141 s.Errorf("Not configured to accept leaf node connections") 1142 // Silently close for now. If we want to send an error back, we would 1143 // need to create the leafnode client anyway, so that is is handling websocket 1144 // frames, then send the error to the remote. 
1145 res.conn.Close() 1146 return 1147 } 1148 s.createLeafNode(res.conn, nil, nil, res.ws) 1149 } 1150 }) 1151 hs := &http.Server{ 1152 Addr: hp, 1153 Handler: mux, 1154 ReadTimeout: o.HandshakeTimeout, 1155 ErrorLog: log.New(&captureHTTPServerLog{s, "websocket: "}, _EMPTY_, 0), 1156 } 1157 s.websocket.server = hs 1158 s.websocket.listener = hl 1159 go func() { 1160 if err := hs.Serve(hl); err != http.ErrServerClosed { 1161 s.Fatalf("websocket listener error: %v", err) 1162 } 1163 if s.isLameDuckMode() { 1164 // Signal that we are not accepting new clients 1165 s.ldmCh <- true 1166 // Now wait for the Shutdown... 1167 <-s.quitCh 1168 return 1169 } 1170 s.done <- true 1171 }() 1172 s.mu.Unlock() 1173 } 1174 1175 // The TLS configuration is passed to the listener when the websocket 1176 // "server" is setup. That prevents TLS configuration updates on reload 1177 // from being used. By setting this function in tls.Config.GetConfigForClient 1178 // we instruct the TLS handshake to ask for the tls configuration to be 1179 // used for a specific client. We don't care which client, we always use 1180 // the same TLS configuration. 1181 func (s *Server) wsGetTLSConfig(_ *tls.ClientHelloInfo) (*tls.Config, error) { 1182 opts := s.getOpts() 1183 return opts.Websocket.TLSConfig, nil 1184 } 1185 1186 // This is similar to createClient() but has some modifications 1187 // specific to handle websocket clients. 1188 // The comments have been kept to minimum to reduce code size. 1189 // Check createClient() for more details. 
1190 func (s *Server) createWSClient(conn net.Conn, ws *websocket) *client { 1191 opts := s.getOpts() 1192 1193 maxPay := int32(opts.MaxPayload) 1194 maxSubs := int32(opts.MaxSubs) 1195 if maxSubs == 0 { 1196 maxSubs = -1 1197 } 1198 now := time.Now().UTC() 1199 1200 c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws} 1201 1202 c.registerWithAccount(s.globalAccount()) 1203 1204 var info Info 1205 var authRequired bool 1206 1207 s.mu.Lock() 1208 info = s.copyInfo() 1209 // Check auth, override if applicable. 1210 if !info.AuthRequired { 1211 // Set info.AuthRequired since this is what is sent to the client. 1212 info.AuthRequired = s.websocket.authOverride 1213 } 1214 if s.nonceRequired() { 1215 var raw [nonceLen]byte 1216 nonce := raw[:] 1217 s.generateNonce(nonce) 1218 info.Nonce = string(nonce) 1219 } 1220 c.nonce = []byte(info.Nonce) 1221 authRequired = info.AuthRequired 1222 1223 s.totalClients++ 1224 s.mu.Unlock() 1225 1226 c.mu.Lock() 1227 if authRequired { 1228 c.flags.set(expectConnect) 1229 } 1230 c.initClient() 1231 c.Debugf("Client connection created") 1232 c.sendProtoNow(c.generateClientInfoJSON(info)) 1233 c.mu.Unlock() 1234 1235 s.mu.Lock() 1236 if !s.isRunning() || s.ldm { 1237 if s.isShuttingDown() { 1238 conn.Close() 1239 } 1240 s.mu.Unlock() 1241 return c 1242 } 1243 1244 if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn { 1245 s.mu.Unlock() 1246 c.maxConnExceeded() 1247 return nil 1248 } 1249 s.clients[c.cid] = c 1250 s.mu.Unlock() 1251 1252 c.mu.Lock() 1253 // Websocket clients do TLS in the websocket http server. 1254 // So no TLS initiation here... 1255 if _, ok := conn.(*tls.Conn); ok { 1256 c.flags.set(handshakeComplete) 1257 } 1258 1259 if c.isClosed() { 1260 c.mu.Unlock() 1261 c.closeConnection(WriteError) 1262 return nil 1263 } 1264 1265 if authRequired { 1266 timeout := opts.AuthTimeout 1267 // Possibly override with Websocket specific value. 
1268 if opts.Websocket.AuthTimeout != 0 { 1269 timeout = opts.Websocket.AuthTimeout 1270 } 1271 c.setAuthTimer(secondsToDuration(timeout)) 1272 } 1273 1274 c.setPingTimer() 1275 1276 s.startGoRoutine(func() { c.readLoop(nil) }) 1277 s.startGoRoutine(func() { c.writeLoop() }) 1278 1279 c.mu.Unlock() 1280 1281 return c 1282 } 1283 1284 func (c *client) wsCollapsePtoNB() (net.Buffers, int64) { 1285 nb := c.out.nb 1286 var mfs int 1287 var usz int 1288 if c.ws.browser { 1289 mfs = wsFrameSizeForBrowsers 1290 } 1291 mask := c.ws.maskwrite 1292 // Start with possible already framed buffers (that we could have 1293 // got from partials or control messages such as ws pings or pongs). 1294 bufs := c.ws.frames 1295 compress := c.ws.compress 1296 if compress && len(nb) > 0 { 1297 // First, make sure we don't compress for very small cumulative buffers. 1298 for _, b := range nb { 1299 usz += len(b) 1300 } 1301 if usz <= wsCompressThreshold { 1302 compress = false 1303 } 1304 } 1305 if compress && len(nb) > 0 { 1306 // Overwrite mfs if this connection does not support fragmented compressed frames. 1307 if mfs > 0 && c.ws.nocompfrag { 1308 mfs = 0 1309 } 1310 buf := bytes.NewBuffer(nbPoolGet(usz)) 1311 cp := c.ws.compressor 1312 if cp == nil { 1313 c.ws.compressor, _ = flate.NewWriter(buf, flate.BestSpeed) 1314 cp = c.ws.compressor 1315 } else { 1316 cp.Reset(buf) 1317 } 1318 var csz int 1319 for _, b := range nb { 1320 cp.Write(b) 1321 nbPoolPut(b) // No longer needed as contents written to compressor. 
1322 } 1323 if err := cp.Flush(); err != nil { 1324 c.Errorf("Error during compression: %v", err) 1325 c.markConnAsClosed(WriteError) 1326 return nil, 0 1327 } 1328 b := buf.Bytes() 1329 p := b[:len(b)-4] 1330 if mfs > 0 && len(p) > mfs { 1331 for first, final := true, false; len(p) > 0; first = false { 1332 lp := len(p) 1333 if lp > mfs { 1334 lp = mfs 1335 } else { 1336 final = true 1337 } 1338 // Only the first frame should be marked as compressed, so pass 1339 // `first` for the compressed boolean. 1340 fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize] 1341 n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp) 1342 if mask { 1343 wsMaskBuf(key, p[:lp]) 1344 } 1345 bufs = append(bufs, fh[:n], p[:lp]) 1346 csz += n + lp 1347 p = p[lp:] 1348 } 1349 } else { 1350 ol := len(p) 1351 h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, ol) 1352 if mask { 1353 wsMaskBuf(key, p) 1354 } 1355 if ol > 0 { 1356 bufs = append(bufs, h, p) 1357 } 1358 csz = len(h) + ol 1359 } 1360 // Make sure that the compressor no longer holds a reference to 1361 // the bytes.Buffer, so that the underlying memory gets cleaned 1362 // up after flushOutbound/flushAndClose. For this to be safe, we 1363 // always cp.Reset(...) before reusing the compressor again. 1364 cp.Reset(nil) 1365 // Add to pb the compressed data size (including headers), but 1366 // remove the original uncompressed data size that was added 1367 // during the queueing. 1368 c.out.pb += int64(csz) - int64(usz) 1369 c.ws.fs += int64(csz) 1370 } else if len(nb) > 0 { 1371 var total int 1372 if mfs > 0 { 1373 // We are limiting the frame size. 
1374 startFrame := func() int { 1375 bufs = append(bufs, nbPoolGet(wsMaxFrameHeaderSize)) 1376 return len(bufs) - 1 1377 } 1378 endFrame := func(idx, size int) { 1379 bufs[idx] = bufs[idx][:wsMaxFrameHeaderSize] 1380 n, key := wsFillFrameHeader(bufs[idx], mask, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, wsBinaryMessage, size) 1381 bufs[idx] = bufs[idx][:n] 1382 c.out.pb += int64(n) 1383 c.ws.fs += int64(n + size) 1384 if mask { 1385 wsMaskBufs(key, bufs[idx+1:]) 1386 } 1387 } 1388 1389 fhIdx := startFrame() 1390 for i := 0; i < len(nb); i++ { 1391 b := nb[i] 1392 if total+len(b) <= mfs { 1393 buf := nbPoolGet(len(b)) 1394 bufs = append(bufs, append(buf, b...)) 1395 total += len(b) 1396 nbPoolPut(nb[i]) 1397 continue 1398 } 1399 for len(b) > 0 { 1400 endStart := total != 0 1401 if endStart { 1402 endFrame(fhIdx, total) 1403 } 1404 total = len(b) 1405 if total >= mfs { 1406 total = mfs 1407 } 1408 if endStart { 1409 fhIdx = startFrame() 1410 } 1411 buf := nbPoolGet(total) 1412 bufs = append(bufs, append(buf, b[:total]...)) 1413 b = b[total:] 1414 } 1415 nbPoolPut(nb[i]) // No longer needed as copied into smaller frames. 1416 } 1417 if total > 0 { 1418 endFrame(fhIdx, total) 1419 } 1420 } else { 1421 // If there is no limit on the frame size, create a single frame for 1422 // all pending buffers. 1423 for _, b := range nb { 1424 total += len(b) 1425 } 1426 wsfh, key := wsCreateFrameHeader(mask, false, wsBinaryMessage, total) 1427 c.out.pb += int64(len(wsfh)) 1428 bufs = append(bufs, wsfh) 1429 idx := len(bufs) 1430 bufs = append(bufs, nb...) 
1431 if mask { 1432 wsMaskBufs(key, bufs[idx:]) 1433 } 1434 c.ws.fs += int64(len(wsfh) + total) 1435 } 1436 } 1437 if len(c.ws.closeMsg) > 0 { 1438 bufs = append(bufs, c.ws.closeMsg) 1439 c.ws.fs += int64(len(c.ws.closeMsg)) 1440 c.ws.closeMsg = nil 1441 } 1442 c.ws.frames = nil 1443 return bufs, c.ws.fs 1444 } 1445 1446 func isWSURL(u *url.URL) bool { 1447 return strings.HasPrefix(strings.ToLower(u.Scheme), wsSchemePrefix) 1448 } 1449 1450 func isWSSURL(u *url.URL) bool { 1451 return strings.HasPrefix(strings.ToLower(u.Scheme), wsSchemePrefixTLS) 1452 }