// Source: github.com/nats-io/nats-server/v2@v2.11.0-preview.2/server/websocket.go

// Copyright 2020-2024 The NATS Authors
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package server

import (
	"bytes"
	crand "crypto/rand"
	"crypto/sha1"
	"crypto/tls"
	"encoding/base64"
	"encoding/binary"
	"errors"
	"fmt"
	"io"
	"log"
	mrand "math/rand"
	"net"
	"net/http"
	"net/url"
	"strconv"
	"strings"
	"sync"
	"time"
	"unicode/utf8"

	"github.com/klauspost/compress/flate"
)

// wsOpCode is a websocket frame opcode (the low 4 bits of the first header byte).
type wsOpCode int

const (
	// From https://tools.ietf.org/html/rfc6455#section-5.2
	wsTextMessage   = wsOpCode(1)
	wsBinaryMessage = wsOpCode(2)
	wsCloseMessage  = wsOpCode(8)
	wsPingMessage   = wsOpCode(9)
	wsPongMessage   = wsOpCode(10)

	// Bits of the first header byte.
	wsFinalBit = 1 << 7
	wsRsv1Bit  = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6
	wsRsv2Bit  = 1 << 5
	wsRsv3Bit  = 1 << 4

	// Bit of the second header byte that says the payload is masked.
	wsMaskBit = 1 << 7

	wsContinuationFrame = 0
	// 2 fixed bytes + up to 8 bytes of extended length + 4 bytes of mask key.
	wsMaxFrameHeaderSize    = 14 // Since LeafNode may need to behave as a client
	wsMaxControlPayloadSize = 125
	wsFrameSizeForBrowsers  = 4096 // From experiment, web browsers behave better with limited frame size
	wsCompressThreshold     = 64   // Don't compress for small buffer(s)
	// Size in bytes of the status code that prefixes a close frame payload.
	// (The "Satus" spelling is kept because the name is referenced elsewhere.)
	wsCloseSatusSize = 2

	// From https://tools.ietf.org/html/rfc6455#section-11.7
	wsCloseStatusNormalClosure      = 1000
	wsCloseStatusGoingAway          = 1001
	wsCloseStatusProtocolError      = 1002
	wsCloseStatusUnsupportedData    = 1003
	wsCloseStatusNoStatusReceived   = 1005
	wsCloseStatusAbnormalClosure    = 1006
	wsCloseStatusInvalidPayloadData = 1007
	wsCloseStatusPolicyViolation    = 1008
	wsCloseStatusMessageTooBig      = 1009
	wsCloseStatusInternalSrvError   = 1011
	wsCloseStatusTLSHandshake       = 1015

	// Readable names for the boolean arguments of wsFillFrameHeader.
	wsFirstFrame        = true
	wsContFrame         = false
	wsFinalFrame        = true
	wsUncompressedFrame = false

	wsSchemePrefix    = "ws"
	wsSchemePrefixTLS = "wss"

	// Header names/values and pre-built header lines used during the handshake.
	wsNoMaskingHeader       = "Nats-No-Masking"
	wsNoMaskingValue        = "true"
	wsXForwardedForHeader   = "X-Forwarded-For"
	wsNoMaskingFullResponse = wsNoMaskingHeader + ": " + wsNoMaskingValue + CR_LF
	wsPMCExtension          = "permessage-deflate" // per-message compression
	wsPMCSrvNoCtx           = "server_no_context_takeover"
	wsPMCCliNoCtx           = "client_no_context_takeover"
	wsPMCReqHeaderValue     = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx
	wsPMCFullResponse       = "Sec-WebSocket-Extensions: " + wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx + _CRLF_
	wsSecProto              = "Sec-Websocket-Protocol"
	wsMQTTSecProtoVal       = "mqtt"
	wsMQTTSecProto          = wsSecProto + ": " + wsMQTTSecProtoVal + CR_LF
)

// Pool of flate readers reused by (*wsReadInfo).decompress.
var decompressorPool sync.Pool

// Appended to the compressed fragments before decompression: the
// 0x00 0x00 0xff 0xff trailer removed by the sender (RFC 7692 section 7.2.2)
// followed by a final empty block so the flate reader does not report
// an unexpected EOF.
var compressLastBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}

// From https://tools.ietf.org/html/rfc6455#section-1.3
// Concatenated with the client's Sec-WebSocket-Key to compute the accept key.
var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// Test can enable this so that server does not support "no-masking" requests.
var wsTestRejectNoMasking = false

// websocket holds the websocket-specific state of a client connection.
type websocket struct {
	// Outbound frames waiting to be sent, and their total size in bytes.
	frames net.Buffers
	fs     int64
	// Close frame kept aside; it is added to the frames on a flushOutbound()
	// once there are no more pending buffers to send.
	closeMsg []byte
	// Set during the upgrade when permessage-deflate was negotiated.
	compress bool
	// Set once a close frame has been enqueued (only one should be sent).
	closeSent bool
	// The User-Agent starts with "Mozilla/".
	browser    bool
	nocompfrag bool // No fragment for compressed frames
	// Set to !noMasking during the upgrade.
	maskread bool
	// When set, outgoing control frames are masked with a random key.
	maskwrite bool
	// NOTE(review): presumably the deflate writer for outbound compression;
	// it is not used in this part of the file.
	compressor *flate.Writer
	// Values of the cookies whose names are configured in the websocket options.
	cookieJwt      string
	cookieUsername string
	cookiePassword string
	cookieToken    string
	// Taken from the X-Forwarded-For header when it parses as an IP.
	clientIP string
}

// srvWebsocket holds the server-wide websocket listener state.
type srvWebsocket struct {
	// Guards sameOrigin, allowedOrigins and rawHeaders.
	mu          sync.RWMutex
	server      *http.Server
	listener    net.Listener
	listenerErr error
	tls         bool
	// Keyed by lower-cased host.
	allowedOrigins map[string]*allowedOrigin // host will be the key
	sameOrigin     bool
	connectURLs    []string
	connectURLsMap refCountedUrlSet
	authOverride   bool   // indicate if there is auth override in websocket config
	rawHeaders     string // raw headers to be used in the upgrade response.
}

// allowedOrigin is the scheme and port accepted for a given allowed host.
type allowedOrigin struct {
	scheme string
	port   string
}

// wsUpgradeResult is what a successful handshake returns: the hijacked
// connection, its websocket state and the connection kind (CLIENT, LEAF or MQTT).
type wsUpgradeResult struct {
	conn net.Conn
	ws   *websocket
	kind int
}

// wsReadInfo is the frame-parsing state kept across calls to wsRead.
type wsReadInfo struct {
	rem  int  // Payload bytes of the current frame not yet consumed.
	fs   bool // The next byte to parse is the start of a frame header.
	ff   bool // The final frame of the current message has been seen.
	fc   bool // The current message is compressed (RSV1 set on its first frame).
	mask bool // Incoming leafnode connections may not have masking.
	// Offset in mkey to use for the next payload byte.
	mkpos byte
	mkey  [4]byte
	// Copies of the compressed fragments of the current message, and the
	// read offset in cbufs[0] used by Read/ReadByte during decompression.
	cbufs [][]byte
	coff  int
}

// init prepares r for a new connection: the first byte parsed is a frame
// header and a new message may start.
func (r *wsReadInfo) init() {
	r.fs, r.ff = true, true
}

// Returns a slice containing `needed` bytes from the given buffer `buf`
// starting at position `pos`, and possibly read from the given reader `r`.
// When bytes are present in `buf`, the `pos` is incremented by the number
// of bytes found up to `needed` and the new position is returned. If not
// enough bytes are found, the bytes found in `buf` are copied to the returned
// slice and the remaining bytes are read from `r`.
173 func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) { 174 avail := len(buf) - pos 175 if avail >= needed { 176 return buf[pos : pos+needed], pos + needed, nil 177 } 178 b := make([]byte, needed) 179 start := copy(b, buf[pos:]) 180 for start != needed { 181 n, err := r.Read(b[start:cap(b)]) 182 if err != nil { 183 return nil, 0, err 184 } 185 start += n 186 } 187 return b, pos + avail, nil 188 } 189 190 // Returns true if this connection is from a Websocket client. 191 // Lock held on entry. 192 func (c *client) isWebsocket() bool { 193 return c.ws != nil 194 } 195 196 // Returns a slice of byte slices corresponding to payload of websocket frames. 197 // The byte slice `buf` is filled with bytes from the connection's read loop. 198 // This function will decode the frame headers and unmask the payload(s). 199 // It is possible that the returned slices point to the given `buf` slice, so 200 // `buf` should not be overwritten until the returned slices have been parsed. 201 // 202 // Client lock MUST NOT be held on entry. 203 func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, error) { 204 var ( 205 bufs [][]byte 206 tmpBuf []byte 207 err error 208 pos int 209 max = len(buf) 210 ) 211 for pos != max { 212 if r.fs { 213 b0 := buf[pos] 214 frameType := wsOpCode(b0 & 0xF) 215 final := b0&wsFinalBit != 0 216 compressed := b0&wsRsv1Bit != 0 217 pos++ 218 219 tmpBuf, pos, err = wsGet(ior, buf, pos, 1) 220 if err != nil { 221 return bufs, err 222 } 223 b1 := tmpBuf[0] 224 225 // Clients MUST set the mask bit. If not set, reject. 226 // However, LEAF by default will not have masking, unless they are forced to, by configuration. 
227 if r.mask && b1&wsMaskBit == 0 { 228 return bufs, c.wsHandleProtocolError("mask bit missing") 229 } 230 231 // Store size in case it is < 125 232 r.rem = int(b1 & 0x7F) 233 234 switch frameType { 235 case wsPingMessage, wsPongMessage, wsCloseMessage: 236 if r.rem > wsMaxControlPayloadSize { 237 return bufs, c.wsHandleProtocolError( 238 fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes", 239 wsMaxControlPayloadSize)) 240 } 241 if !final { 242 return bufs, c.wsHandleProtocolError("control frame does not have final bit set") 243 } 244 case wsTextMessage, wsBinaryMessage: 245 if !r.ff { 246 return bufs, c.wsHandleProtocolError("new message started before final frame for previous message was received") 247 } 248 r.ff = final 249 r.fc = compressed 250 case wsContinuationFrame: 251 // Compressed bit must be only set in the first frame 252 if r.ff || compressed { 253 return bufs, c.wsHandleProtocolError("invalid continuation frame") 254 } 255 r.ff = final 256 default: 257 return bufs, c.wsHandleProtocolError(fmt.Sprintf("unknown opcode %v", frameType)) 258 } 259 260 switch r.rem { 261 case 126: 262 tmpBuf, pos, err = wsGet(ior, buf, pos, 2) 263 if err != nil { 264 return bufs, err 265 } 266 r.rem = int(binary.BigEndian.Uint16(tmpBuf)) 267 case 127: 268 tmpBuf, pos, err = wsGet(ior, buf, pos, 8) 269 if err != nil { 270 return bufs, err 271 } 272 r.rem = int(binary.BigEndian.Uint64(tmpBuf)) 273 } 274 275 if r.mask { 276 // Read masking key 277 tmpBuf, pos, err = wsGet(ior, buf, pos, 4) 278 if err != nil { 279 return bufs, err 280 } 281 copy(r.mkey[:], tmpBuf) 282 r.mkpos = 0 283 } 284 285 // Handle control messages in place... 
286 if wsIsControlFrame(frameType) { 287 pos, err = c.wsHandleControlFrame(r, frameType, ior, buf, pos) 288 if err != nil { 289 return bufs, err 290 } 291 continue 292 } 293 294 // Done with the frame header 295 r.fs = false 296 } 297 if pos < max { 298 var b []byte 299 var n int 300 301 n = r.rem 302 if pos+n > max { 303 n = max - pos 304 } 305 b = buf[pos : pos+n] 306 pos += n 307 r.rem -= n 308 // If needed, unmask the buffer 309 if r.mask { 310 r.unmask(b) 311 } 312 addToBufs := true 313 // Handle compressed message 314 if r.fc { 315 // Assume that we may have continuation frames or not the full payload. 316 addToBufs = false 317 // Make a copy of the buffer before adding it to the list 318 // of compressed fragments. 319 r.cbufs = append(r.cbufs, append([]byte(nil), b...)) 320 // When we have the final frame and we have read the full payload, 321 // we can decompress it. 322 if r.ff && r.rem == 0 { 323 b, err = r.decompress() 324 if err != nil { 325 return bufs, err 326 } 327 r.fc = false 328 // Now we can add to `bufs` 329 addToBufs = true 330 } 331 } 332 // For non compressed frames, or when we have decompressed the 333 // whole message. 334 if addToBufs { 335 bufs = append(bufs, b) 336 } 337 // If payload has been fully read, then indicate that next 338 // is the start of a frame. 
339 if r.rem == 0 { 340 r.fs = true 341 } 342 } 343 } 344 return bufs, nil 345 } 346 347 func (r *wsReadInfo) Read(dst []byte) (int, error) { 348 if len(dst) == 0 { 349 return 0, nil 350 } 351 if len(r.cbufs) == 0 { 352 return 0, io.EOF 353 } 354 copied := 0 355 rem := len(dst) 356 for buf := r.cbufs[0]; buf != nil && rem > 0; { 357 n := len(buf[r.coff:]) 358 if n > rem { 359 n = rem 360 } 361 copy(dst[copied:], buf[r.coff:r.coff+n]) 362 copied += n 363 rem -= n 364 r.coff += n 365 buf = r.nextCBuf() 366 } 367 return copied, nil 368 } 369 370 func (r *wsReadInfo) nextCBuf() []byte { 371 // We still have remaining data in the first buffer 372 if r.coff != len(r.cbufs[0]) { 373 return r.cbufs[0] 374 } 375 // We read the full first buffer. Reset offset. 376 r.coff = 0 377 // We were at the last buffer, so we are done. 378 if len(r.cbufs) == 1 { 379 r.cbufs = nil 380 return nil 381 } 382 // Here we move to the next buffer. 383 r.cbufs = r.cbufs[1:] 384 return r.cbufs[0] 385 } 386 387 func (r *wsReadInfo) ReadByte() (byte, error) { 388 if len(r.cbufs) == 0 { 389 return 0, io.EOF 390 } 391 b := r.cbufs[0][r.coff] 392 r.coff++ 393 r.nextCBuf() 394 return b, nil 395 } 396 397 func (r *wsReadInfo) decompress() ([]byte, error) { 398 r.coff = 0 399 // As per https://tools.ietf.org/html/rfc7692#section-7.2.2 400 // add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader 401 // does not report unexpected EOF. 402 r.cbufs = append(r.cbufs, compressLastBlock) 403 // Get a decompressor from the pool and bind it to this object (wsReadInfo) 404 // that provides Read() and ReadByte() APIs that will consume the compressed 405 // buffers (r.cbufs). 406 d, _ := decompressorPool.Get().(io.ReadCloser) 407 if d == nil { 408 d = flate.NewReader(r) 409 } else { 410 d.(flate.Resetter).Reset(r, nil) 411 } 412 // This will do the decompression. 413 b, err := io.ReadAll(d) 414 decompressorPool.Put(d) 415 // Now reset the compressed buffers list. 
416 r.cbufs = nil 417 return b, err 418 } 419 420 // Handles the PING, PONG and CLOSE websocket control frames. 421 // 422 // Client lock MUST NOT be held on entry. 423 func (c *client) wsHandleControlFrame(r *wsReadInfo, frameType wsOpCode, nc io.Reader, buf []byte, pos int) (int, error) { 424 var payload []byte 425 var err error 426 427 if r.rem > 0 { 428 payload, pos, err = wsGet(nc, buf, pos, r.rem) 429 if err != nil { 430 return pos, err 431 } 432 if r.mask { 433 r.unmask(payload) 434 } 435 r.rem = 0 436 } 437 switch frameType { 438 case wsCloseMessage: 439 status := wsCloseStatusNoStatusReceived 440 var body string 441 lp := len(payload) 442 // If there is a payload, the status is represented as a 2-byte 443 // unsigned integer (in network byte order). Then, there may be an 444 // optional body. 445 hasStatus, hasBody := lp >= wsCloseSatusSize, lp > wsCloseSatusSize 446 if hasStatus { 447 // Decode the status 448 status = int(binary.BigEndian.Uint16(payload[:wsCloseSatusSize])) 449 // Now if there is a body, capture it and make sure this is a valid UTF-8. 450 if hasBody { 451 body = string(payload[wsCloseSatusSize:]) 452 if !utf8.ValidString(body) { 453 // https://tools.ietf.org/html/rfc6455#section-5.5.1 454 // If body is present, it must be a valid utf8 455 status = wsCloseStatusInvalidPayloadData 456 body = "invalid utf8 body in close frame" 457 } 458 } 459 } 460 clm := wsCreateCloseMessage(status, body) 461 c.wsEnqueueControlMessage(wsCloseMessage, clm) 462 nbPoolPut(clm) // wsEnqueueControlMessage has taken a copy. 463 // Return io.EOF so that readLoop will close the connection as ClientClosed 464 // after processing pending buffers. 465 return pos, io.EOF 466 case wsPingMessage: 467 c.wsEnqueueControlMessage(wsPongMessage, payload) 468 case wsPongMessage: 469 // Nothing to do.. 470 } 471 return pos, nil 472 } 473 474 // Unmask the given slice. 
475 func (r *wsReadInfo) unmask(buf []byte) { 476 p := int(r.mkpos) 477 if len(buf) < 16 { 478 for i := 0; i < len(buf); i++ { 479 buf[i] ^= r.mkey[p&3] 480 p++ 481 } 482 r.mkpos = byte(p & 3) 483 return 484 } 485 var k [8]byte 486 for i := 0; i < 8; i++ { 487 k[i] = r.mkey[(p+i)&3] 488 } 489 km := binary.BigEndian.Uint64(k[:]) 490 n := (len(buf) / 8) * 8 491 for i := 0; i < n; i += 8 { 492 tmp := binary.BigEndian.Uint64(buf[i : i+8]) 493 tmp ^= km 494 binary.BigEndian.PutUint64(buf[i:], tmp) 495 } 496 buf = buf[n:] 497 for i := 0; i < len(buf); i++ { 498 buf[i] ^= r.mkey[p&3] 499 p++ 500 } 501 r.mkpos = byte(p & 3) 502 } 503 504 // Returns true if the op code corresponds to a control frame. 505 func wsIsControlFrame(frameType wsOpCode) bool { 506 return frameType >= wsCloseMessage 507 } 508 509 // Create the frame header. 510 // Encodes the frame type and optional compression flag, and the size of the payload. 511 func wsCreateFrameHeader(useMasking, compressed bool, frameType wsOpCode, l int) ([]byte, []byte) { 512 fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize] 513 n, key := wsFillFrameHeader(fh, useMasking, wsFirstFrame, wsFinalFrame, compressed, frameType, l) 514 return fh[:n], key 515 } 516 517 func wsFillFrameHeader(fh []byte, useMasking, first, final, compressed bool, frameType wsOpCode, l int) (int, []byte) { 518 var n int 519 var b byte 520 if first { 521 b = byte(frameType) 522 } 523 if final { 524 b |= wsFinalBit 525 } 526 if compressed { 527 b |= wsRsv1Bit 528 } 529 b1 := byte(0) 530 if useMasking { 531 b1 |= wsMaskBit 532 } 533 switch { 534 case l <= 125: 535 n = 2 536 fh[0] = b 537 fh[1] = b1 | byte(l) 538 case l < 65536: 539 n = 4 540 fh[0] = b 541 fh[1] = b1 | 126 542 binary.BigEndian.PutUint16(fh[2:], uint16(l)) 543 default: 544 n = 10 545 fh[0] = b 546 fh[1] = b1 | 127 547 binary.BigEndian.PutUint64(fh[2:], uint64(l)) 548 } 549 var key []byte 550 if useMasking { 551 var keyBuf [4]byte 552 if _, err := io.ReadFull(crand.Reader, 
keyBuf[:4]); err != nil { 553 kv := mrand.Int31() 554 binary.LittleEndian.PutUint32(keyBuf[:4], uint32(kv)) 555 } 556 copy(fh[n:], keyBuf[:4]) 557 key = fh[n : n+4] 558 n += 4 559 } 560 return n, key 561 } 562 563 // Invokes wsEnqueueControlMessageLocked under client lock. 564 // 565 // Client lock MUST NOT be held on entry 566 func (c *client) wsEnqueueControlMessage(controlMsg wsOpCode, payload []byte) { 567 c.mu.Lock() 568 c.wsEnqueueControlMessageLocked(controlMsg, payload) 569 c.mu.Unlock() 570 } 571 572 // Mask the buffer with the given key 573 func wsMaskBuf(key, buf []byte) { 574 for i := 0; i < len(buf); i++ { 575 buf[i] ^= key[i&3] 576 } 577 } 578 579 // Mask the buffers, as if they were contiguous, with the given key 580 func wsMaskBufs(key []byte, bufs [][]byte) { 581 pos := 0 582 for i := 0; i < len(bufs); i++ { 583 buf := bufs[i] 584 for j := 0; j < len(buf); j++ { 585 buf[j] ^= key[pos&3] 586 pos++ 587 } 588 } 589 } 590 591 // Enqueues a websocket control message. 592 // If the control message is a wsCloseMessage, then marks this client 593 // has having sent the close message (since only one should be sent). 594 // This will prevent the generic closeConnection() to enqueue one. 595 // 596 // Client lock held on entry. 597 func (c *client) wsEnqueueControlMessageLocked(controlMsg wsOpCode, payload []byte) { 598 // Control messages are never compressed and their size will be 599 // less than wsMaxControlPayloadSize, which means the frame header 600 // will be only 2 or 6 bytes. 601 useMasking := c.ws.maskwrite 602 sz := 2 603 if useMasking { 604 sz += 4 605 } 606 cm := nbPoolGet(sz + len(payload)) 607 cm = cm[:cap(cm)] 608 n, key := wsFillFrameHeader(cm, useMasking, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, controlMsg, len(payload)) 609 cm = cm[:n] 610 // Note that payload is optional. 611 if len(payload) > 0 { 612 cm = append(cm, payload...) 
613 if useMasking { 614 wsMaskBuf(key, cm[n:]) 615 } 616 } 617 c.out.pb += int64(len(cm)) 618 if controlMsg == wsCloseMessage { 619 // We can't add the close message to the frames buffers 620 // now. It will be done on a flushOutbound() when there 621 // are no more pending buffers to send. 622 c.ws.closeSent = true 623 c.ws.closeMsg = cm 624 } else { 625 c.ws.frames = append(c.ws.frames, cm) 626 c.ws.fs += int64(len(cm)) 627 } 628 c.flushSignal() 629 } 630 631 // Enqueues a websocket close message with a status mapped from the given `reason`. 632 // 633 // Client lock held on entry 634 func (c *client) wsEnqueueCloseMessage(reason ClosedState) { 635 var status int 636 switch reason { 637 case ClientClosed: 638 status = wsCloseStatusNormalClosure 639 case AuthenticationTimeout, AuthenticationViolation, SlowConsumerPendingBytes, SlowConsumerWriteDeadline, 640 MaxAccountConnectionsExceeded, MaxConnectionsExceeded, MaxControlLineExceeded, MaxSubscriptionsExceeded, 641 MissingAccount, AuthenticationExpired, Revocation: 642 status = wsCloseStatusPolicyViolation 643 case TLSHandshakeError: 644 status = wsCloseStatusTLSHandshake 645 case ParseError, ProtocolViolation, BadClientProtocolVersion: 646 status = wsCloseStatusProtocolError 647 case MaxPayloadExceeded: 648 status = wsCloseStatusMessageTooBig 649 case ServerShutdown: 650 status = wsCloseStatusGoingAway 651 case WriteError, ReadError, StaleConnection: 652 status = wsCloseStatusAbnormalClosure 653 default: 654 status = wsCloseStatusInternalSrvError 655 } 656 body := wsCreateCloseMessage(status, reason.String()) 657 c.wsEnqueueControlMessageLocked(wsCloseMessage, body) 658 nbPoolPut(body) // wsEnqueueControlMessageLocked has taken a copy. 659 } 660 661 // Create and then enqueue a close message with a protocol error and the 662 // given message. This is invoked when parsing websocket frames. 663 // 664 // Lock MUST NOT be held on entry. 
665 func (c *client) wsHandleProtocolError(message string) error { 666 buf := wsCreateCloseMessage(wsCloseStatusProtocolError, message) 667 c.wsEnqueueControlMessage(wsCloseMessage, buf) 668 nbPoolPut(buf) // wsEnqueueControlMessage has taken a copy. 669 return fmt.Errorf(message) 670 } 671 672 // Create a close message with the given `status` and `body`. 673 // If the `body` is more than the maximum allows control frame payload size, 674 // it is truncated and "..." is added at the end (as a hint that message 675 // is not complete). 676 func wsCreateCloseMessage(status int, body string) []byte { 677 // Since a control message payload is limited in size, we 678 // will limit the text and add trailing "..." if truncated. 679 // The body of a Close Message must be preceded with 2 bytes, 680 // so take that into account for limiting the body length. 681 if len(body) > wsMaxControlPayloadSize-2 { 682 body = body[:wsMaxControlPayloadSize-5] 683 body += "..." 684 } 685 buf := nbPoolGet(2 + len(body))[:2+len(body)] 686 // We need to have a 2 byte unsigned int that represents the error status code 687 // https://tools.ietf.org/html/rfc6455#section-5.5.1 688 binary.BigEndian.PutUint16(buf[:2], uint16(status)) 689 copy(buf[2:], []byte(body)) 690 return buf 691 } 692 693 // Process websocket client handshake. On success, returns the raw net.Conn that 694 // will be used to create a *client object. 695 // Invoked from the HTTP server listening on websocket port. 696 func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeResult, error) { 697 kind := CLIENT 698 if r.URL != nil { 699 ep := r.URL.EscapedPath() 700 if strings.HasSuffix(ep, leafNodeWSPath) { 701 kind = LEAF 702 } else if strings.HasSuffix(ep, mqttWSPath) { 703 kind = MQTT 704 } 705 } 706 707 opts := s.getOpts() 708 709 // From https://tools.ietf.org/html/rfc6455#section-4.2.1 710 // Point 1. 
711 if r.Method != "GET" { 712 return nil, wsReturnHTTPError(w, r, http.StatusMethodNotAllowed, "request method must be GET") 713 } 714 // Point 2. 715 if r.Host == _EMPTY_ { 716 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "'Host' missing in request") 717 } 718 // Point 3. 719 if !wsHeaderContains(r.Header, "Upgrade", "websocket") { 720 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Upgrade'") 721 } 722 // Point 4. 723 if !wsHeaderContains(r.Header, "Connection", "Upgrade") { 724 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Connection'") 725 } 726 // Point 5. 727 key := r.Header.Get("Sec-Websocket-Key") 728 if key == _EMPTY_ { 729 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "key missing") 730 } 731 // Point 6. 732 if !wsHeaderContains(r.Header, "Sec-Websocket-Version", "13") { 733 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid version") 734 } 735 // Others are optional 736 // Point 7. 737 if err := s.websocket.checkOrigin(r); err != nil { 738 return nil, wsReturnHTTPError(w, r, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err)) 739 } 740 // Point 8. 741 // We don't have protocols, so ignore. 742 // Point 9. 743 // Extensions, only support for compression at the moment 744 compress := opts.Websocket.Compression 745 if compress { 746 // Simply check if permessage-deflate extension is present. 
747 compress, _ = wsPMCExtensionSupport(r.Header, true) 748 } 749 // We will do masking if asked (unless we reject for tests) 750 noMasking := r.Header.Get(wsNoMaskingHeader) == wsNoMaskingValue && !wsTestRejectNoMasking 751 752 h := w.(http.Hijacker) 753 conn, brw, err := h.Hijack() 754 if err != nil { 755 if conn != nil { 756 conn.Close() 757 } 758 return nil, wsReturnHTTPError(w, r, http.StatusInternalServerError, err.Error()) 759 } 760 if brw.Reader.Buffered() > 0 { 761 conn.Close() 762 return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "client sent data before handshake is complete") 763 } 764 765 var buf [1024]byte 766 p := buf[:0] 767 768 // From https://tools.ietf.org/html/rfc6455#section-4.2.2 769 p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...) 770 p = append(p, wsAcceptKey(key)...) 771 p = append(p, _CRLF_...) 772 if compress { 773 p = append(p, wsPMCFullResponse...) 774 } 775 if noMasking { 776 p = append(p, wsNoMaskingFullResponse...) 777 } 778 if kind == MQTT { 779 p = append(p, wsMQTTSecProto...) 780 } 781 if s.websocket.rawHeaders != _EMPTY_ { 782 p = append(p, s.websocket.rawHeaders...) 783 } 784 p = append(p, _CRLF_...) 785 786 if _, err = conn.Write(p); err != nil { 787 conn.Close() 788 return nil, err 789 } 790 // If there was a deadline set for the handshake, clear it now. 791 if opts.Websocket.HandshakeTimeout > 0 { 792 conn.SetDeadline(time.Time{}) 793 } 794 // Server always expect "clients" to send masked payload, unless the option 795 // "no-masking" has been enabled. 796 ws := &websocket{compress: compress, maskread: !noMasking} 797 798 // Check for X-Forwarded-For header 799 if cips, ok := r.Header[wsXForwardedForHeader]; ok { 800 cip := cips[0] 801 if net.ParseIP(cip) != nil { 802 ws.clientIP = cip 803 } 804 } 805 806 if kind == CLIENT || kind == MQTT { 807 // Indicate if this is likely coming from a browser. 
808 if ua := r.Header.Get("User-Agent"); ua != _EMPTY_ && strings.HasPrefix(ua, "Mozilla/") { 809 ws.browser = true 810 // Disable fragmentation of compressed frames for Safari browsers. 811 // Unfortunately, you could be running Chrome on macOS and this 812 // string will contain "Safari/" (along "Chrome/"). However, what 813 // I have found is that actual Safari browser also have "Version/". 814 // So make the combination of the two. 815 ws.nocompfrag = ws.compress && strings.Contains(ua, "Version/") && strings.Contains(ua, "Safari/") 816 } 817 818 if cookies := r.Cookies(); len(cookies) > 0 { 819 ows := &opts.Websocket 820 for _, c := range cookies { 821 if ows.JWTCookie == c.Name { 822 ws.cookieJwt = c.Value 823 } else if ows.UsernameCookie == c.Name { 824 ws.cookieUsername = c.Value 825 } else if ows.PasswordCookie == c.Name { 826 ws.cookiePassword = c.Value 827 } else if ows.TokenCookie == c.Name { 828 ws.cookieToken = c.Value 829 } 830 } 831 } 832 } 833 return &wsUpgradeResult{conn: conn, ws: ws, kind: kind}, nil 834 } 835 836 // Returns true if the header named `name` contains a token with value `value`. 
837 func wsHeaderContains(header http.Header, name string, value string) bool { 838 for _, s := range header[name] { 839 tokens := strings.Split(s, ",") 840 for _, t := range tokens { 841 t = strings.Trim(t, " \t") 842 if strings.EqualFold(t, value) { 843 return true 844 } 845 } 846 } 847 return false 848 } 849 850 func wsPMCExtensionSupport(header http.Header, checkPMCOnly bool) (bool, bool) { 851 for _, extensionList := range header["Sec-Websocket-Extensions"] { 852 extensions := strings.Split(extensionList, ",") 853 for _, extension := range extensions { 854 extension = strings.Trim(extension, " \t") 855 params := strings.Split(extension, ";") 856 for i, p := range params { 857 p = strings.Trim(p, " \t") 858 if strings.EqualFold(p, wsPMCExtension) { 859 if checkPMCOnly { 860 return true, false 861 } 862 var snc bool 863 var cnc bool 864 for j := i + 1; j < len(params); j++ { 865 p = params[j] 866 p = strings.Trim(p, " \t") 867 if strings.EqualFold(p, wsPMCSrvNoCtx) { 868 snc = true 869 } else if strings.EqualFold(p, wsPMCCliNoCtx) { 870 cnc = true 871 } 872 if snc && cnc { 873 return true, true 874 } 875 } 876 return true, false 877 } 878 } 879 } 880 } 881 return false, false 882 } 883 884 // Send an HTTP error with the given `status` to the given http response writer `w`. 885 // Return an error created based on the `reason` string. 886 func wsReturnHTTPError(w http.ResponseWriter, r *http.Request, status int, reason string) error { 887 err := fmt.Errorf("%s - websocket handshake error: %s", r.RemoteAddr, reason) 888 w.Header().Set("Sec-Websocket-Version", "13") 889 http.Error(w, http.StatusText(status), status) 890 return err 891 } 892 893 // If the server is configured to accept any origin, then this function returns 894 // `nil` without checking if the Origin is present and valid. This is also 895 // the case if the request does not have the Origin header. 
896 // Otherwise, this will check that the Origin matches the same origin or 897 // any origin in the allowed list. 898 func (w *srvWebsocket) checkOrigin(r *http.Request) error { 899 w.mu.RLock() 900 checkSame := w.sameOrigin 901 listEmpty := len(w.allowedOrigins) == 0 902 w.mu.RUnlock() 903 if !checkSame && listEmpty { 904 return nil 905 } 906 origin := r.Header.Get("Origin") 907 if origin == _EMPTY_ { 908 origin = r.Header.Get("Sec-Websocket-Origin") 909 } 910 // If the header is not present, we will accept. 911 // From https://datatracker.ietf.org/doc/html/rfc6455#section-1.6 912 // "Naturally, when the WebSocket Protocol is used by a dedicated client 913 // directly (i.e., not from a web page through a web browser), the origin 914 // model is not useful, as the client can provide any arbitrary origin string." 915 if origin == _EMPTY_ { 916 return nil 917 } 918 u, err := url.ParseRequestURI(origin) 919 if err != nil { 920 return err 921 } 922 oh, op, err := wsGetHostAndPort(u.Scheme == "https", u.Host) 923 if err != nil { 924 return err 925 } 926 // If checking same origin, compare with the http's request's Host. 927 if checkSame { 928 rh, rp, err := wsGetHostAndPort(r.TLS != nil, r.Host) 929 if err != nil { 930 return err 931 } 932 if oh != rh || op != rp { 933 return errors.New("not same origin") 934 } 935 // I guess it is possible to have cases where one wants to check 936 // same origin, but also that the origin is in the allowed list. 937 // So continue with the next check. 
938 } 939 if !listEmpty { 940 w.mu.RLock() 941 ao := w.allowedOrigins[oh] 942 w.mu.RUnlock() 943 if ao == nil || u.Scheme != ao.scheme || op != ao.port { 944 return errors.New("not in the allowed list") 945 } 946 } 947 return nil 948 } 949 950 func wsGetHostAndPort(tls bool, hostport string) (string, string, error) { 951 host, port, err := net.SplitHostPort(hostport) 952 if err != nil { 953 // If error is missing port, then use defaults based on the scheme 954 if ae, ok := err.(*net.AddrError); ok && strings.Contains(ae.Err, "missing port") { 955 err = nil 956 host = hostport 957 if tls { 958 port = "443" 959 } else { 960 port = "80" 961 } 962 } 963 } 964 return strings.ToLower(host), port, err 965 } 966 967 // Concatenate the key sent by the client with the GUID, then computes the SHA1 hash 968 // and returns it as a based64 encoded string. 969 func wsAcceptKey(key string) string { 970 h := sha1.New() 971 h.Write([]byte(key)) 972 h.Write(wsGUID) 973 return base64.StdEncoding.EncodeToString(h.Sum(nil)) 974 } 975 976 func wsMakeChallengeKey() (string, error) { 977 p := make([]byte, 16) 978 if _, err := io.ReadFull(crand.Reader, p); err != nil { 979 return _EMPTY_, err 980 } 981 return base64.StdEncoding.EncodeToString(p), nil 982 } 983 984 // Validate the websocket related options. 985 func validateWebsocketOptions(o *Options) error { 986 wo := &o.Websocket 987 // If no port is defined, we don't care about other options 988 if wo.Port == 0 { 989 return nil 990 } 991 // Enforce TLS... unless NoTLS is set to true. 992 if wo.TLSConfig == nil && !wo.NoTLS { 993 return errors.New("websocket requires TLS configuration") 994 } 995 // Make sure that allowed origins, if specified, can be parsed. 996 for _, ao := range wo.AllowedOrigins { 997 if _, err := url.Parse(ao); err != nil { 998 return fmt.Errorf("unable to parse allowed origin: %v", err) 999 } 1000 } 1001 // If there is a NoAuthUser, we need to have Users defined and 1002 // the user to be present. 
1003 if wo.NoAuthUser != _EMPTY_ { 1004 if err := validateNoAuthUser(o, wo.NoAuthUser); err != nil { 1005 return err 1006 } 1007 } 1008 // Token/Username not possible if there are users/nkeys 1009 if len(o.Users) > 0 || len(o.Nkeys) > 0 { 1010 if wo.Username != _EMPTY_ { 1011 return fmt.Errorf("websocket authentication username not compatible with presence of users/nkeys") 1012 } 1013 if wo.Token != _EMPTY_ { 1014 return fmt.Errorf("websocket authentication token not compatible with presence of users/nkeys") 1015 } 1016 } 1017 // Using JWT requires Trusted Keys 1018 if wo.JWTCookie != _EMPTY_ { 1019 if len(o.TrustedOperators) == 0 && len(o.TrustedKeys) == 0 { 1020 return fmt.Errorf("trusted operators or trusted keys configuration is required for JWT authentication via cookie %q", wo.JWTCookie) 1021 } 1022 } 1023 if err := validatePinnedCerts(wo.TLSPinnedCerts); err != nil { 1024 return fmt.Errorf("websocket: %v", err) 1025 } 1026 1027 // Check for invalid headers here. 1028 for key := range wo.Headers { 1029 k := strings.ToLower(key) 1030 switch k { 1031 case "host", 1032 "content-length", 1033 "connection", 1034 "upgrade", 1035 "nats-no-masking": 1036 return fmt.Errorf("websocket: invalid header %q not allowed", key) 1037 } 1038 1039 if strings.HasPrefix(k, "sec-websocket-") { 1040 return fmt.Errorf("websocket: invalid header %q, \"Sec-WebSocket-\" prefix not allowed", key) 1041 } 1042 } 1043 1044 return nil 1045 } 1046 1047 // Creates or updates the existing map 1048 func (s *Server) wsSetOriginOptions(o *WebsocketOpts) { 1049 ws := &s.websocket 1050 ws.mu.Lock() 1051 defer ws.mu.Unlock() 1052 // Copy over the option's same origin boolean 1053 ws.sameOrigin = o.SameOrigin 1054 // Reset the map. Will help for config reload if/when we support it. 
1055 ws.allowedOrigins = nil 1056 if o.AllowedOrigins == nil { 1057 return 1058 } 1059 for _, ao := range o.AllowedOrigins { 1060 // We have previously checked (during options validation) that the urls 1061 // are parseable, but if we get an error, report and skip. 1062 u, err := url.ParseRequestURI(ao) 1063 if err != nil { 1064 s.Errorf("error parsing allowed origin: %v", err) 1065 continue 1066 } 1067 h, p, _ := wsGetHostAndPort(u.Scheme == "https", u.Host) 1068 if ws.allowedOrigins == nil { 1069 ws.allowedOrigins = make(map[string]*allowedOrigin, len(o.AllowedOrigins)) 1070 } 1071 ws.allowedOrigins[h] = &allowedOrigin{scheme: u.Scheme, port: p} 1072 } 1073 } 1074 1075 // Calculate the raw headers for websocket upgrade response. 1076 func (s *Server) wsSetHeadersOptions(o *WebsocketOpts) { 1077 var sb strings.Builder 1078 for k, v := range o.Headers { 1079 sb.WriteString(k) 1080 sb.WriteString(": ") 1081 sb.WriteString(v) 1082 sb.WriteString(_CRLF_) 1083 } 1084 ws := &s.websocket 1085 ws.mu.Lock() 1086 defer ws.mu.Unlock() 1087 ws.rawHeaders = sb.String() 1088 } 1089 1090 // Given the websocket options, we check if any auth configuration 1091 // has been provided. If so, possibly create users/nkey users and 1092 // store them in s.websocket.users/nkeys. 1093 // Also update a boolean that indicates if auth is required for 1094 // websocket clients. 1095 // Server lock is held on entry. 1096 func (s *Server) wsConfigAuth(opts *WebsocketOpts) { 1097 ws := &s.websocket 1098 // If any of those is specified, we consider that there is an override. 
1099 ws.authOverride = opts.Username != _EMPTY_ || opts.Token != _EMPTY_ || opts.NoAuthUser != _EMPTY_ 1100 } 1101 1102 func (s *Server) startWebsocketServer() { 1103 if s.isShuttingDown() { 1104 return 1105 } 1106 1107 sopts := s.getOpts() 1108 o := &sopts.Websocket 1109 1110 s.wsSetOriginOptions(o) 1111 s.wsSetHeadersOptions(o) 1112 1113 var hl net.Listener 1114 var proto string 1115 var err error 1116 1117 port := o.Port 1118 if port == -1 { 1119 port = 0 1120 } 1121 hp := net.JoinHostPort(o.Host, strconv.Itoa(port)) 1122 1123 // We are enforcing (when validating the options) the use of TLS, but the 1124 // code was originally supporting both modes. The reason for TLS only is 1125 // that we expect users to send JWTs with bearer tokens and we want to 1126 // avoid the possibility of it being "intercepted". 1127 1128 s.mu.Lock() 1129 // Do not check o.NoTLS here. If a TLS configuration is available, use it, 1130 // regardless of NoTLS. If we don't have a TLS config, it means that the 1131 // user has configured NoTLS because otherwise the server would have failed 1132 // to start due to options validation. 1133 if o.TLSConfig != nil { 1134 proto = wsSchemePrefixTLS 1135 config := o.TLSConfig.Clone() 1136 config.GetConfigForClient = s.wsGetTLSConfig 1137 hl, err = tls.Listen("tcp", hp, config) 1138 } else { 1139 proto = wsSchemePrefix 1140 hl, err = net.Listen("tcp", hp) 1141 } 1142 s.websocket.listenerErr = err 1143 if err != nil { 1144 s.mu.Unlock() 1145 s.Fatalf("Unable to listen for websocket connections: %v", err) 1146 return 1147 } 1148 if port == 0 { 1149 o.Port = hl.Addr().(*net.TCPAddr).Port 1150 } 1151 s.Noticef("Listening for websocket clients on %s://%s:%d", proto, o.Host, o.Port) 1152 if proto == wsSchemePrefix { 1153 s.Warnf("Websocket not configured with TLS. 
DO NOT USE IN PRODUCTION!") 1154 } 1155 1156 s.websocket.tls = proto == "wss" 1157 s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port) 1158 if err != nil { 1159 s.Fatalf("Unable to get websocket connect URLs: %v", err) 1160 hl.Close() 1161 s.mu.Unlock() 1162 return 1163 } 1164 hasLeaf := sopts.LeafNode.Port != 0 1165 mux := http.NewServeMux() 1166 mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { 1167 res, err := s.wsUpgrade(w, r) 1168 if err != nil { 1169 s.Errorf(err.Error()) 1170 return 1171 } 1172 switch res.kind { 1173 case CLIENT: 1174 s.createWSClient(res.conn, res.ws) 1175 case MQTT: 1176 s.createMQTTClient(res.conn, res.ws) 1177 case LEAF: 1178 if !hasLeaf { 1179 s.Errorf("Not configured to accept leaf node connections") 1180 // Silently close for now. If we want to send an error back, we would 1181 // need to create the leafnode client anyway, so that is handling websocket 1182 // frames, then send the error to the remote. 1183 res.conn.Close() 1184 return 1185 } 1186 s.createLeafNode(res.conn, nil, nil, res.ws) 1187 } 1188 }) 1189 hs := &http.Server{ 1190 Addr: hp, 1191 Handler: mux, 1192 ReadTimeout: o.HandshakeTimeout, 1193 ErrorLog: log.New(&captureHTTPServerLog{s, "websocket: "}, _EMPTY_, 0), 1194 } 1195 s.websocket.server = hs 1196 s.websocket.listener = hl 1197 go func() { 1198 if err := hs.Serve(hl); err != http.ErrServerClosed { 1199 s.Fatalf("websocket listener error: %v", err) 1200 } 1201 if s.isLameDuckMode() { 1202 // Signal that we are not accepting new clients 1203 s.ldmCh <- true 1204 // Now wait for the Shutdown... 1205 <-s.quitCh 1206 return 1207 } 1208 s.done <- true 1209 }() 1210 s.mu.Unlock() 1211 } 1212 1213 // The TLS configuration is passed to the listener when the websocket 1214 // "server" is setup. That prevents TLS configuration updates on reload 1215 // from being used. 
// By setting this function in tls.Config.GetConfigForClient
// we instruct the TLS handshake to ask for the tls configuration to be
// used for a specific client. We don't care which client, we always use
// the same TLS configuration.
//
// wsGetTLSConfig ignores the ClientHelloInfo and returns the websocket
// TLS configuration from the server's current options (s.getOpts()).
func (s *Server) wsGetTLSConfig(_ *tls.ClientHelloInfo) (*tls.Config, error) {
	opts := s.getOpts()
	return opts.Websocket.TLSConfig, nil
}

// This is similar to createClient() but has some modifications
// specific to handle websocket clients.
// The comments have been kept to minimum to reduce code size.
// Check createClient() for more details.
//
// createWSClient builds a client for an already-upgraded websocket
// connection, sends it the INFO protocol and starts its read/write loops.
// Returns:
//   - c without registering it in s.clients when the server is not running
//     or is in lame duck mode (conn is closed only if shutting down);
//   - nil when MaxConn is reached or the client is already closed;
//   - c otherwise, with auth and ping timers armed.
func (s *Server) createWSClient(conn net.Conn, ws *websocket) *client {
	opts := s.getOpts()

	maxPay := int32(opts.MaxPayload)
	maxSubs := int32(opts.MaxSubs)
	// A MaxSubs of 0 is mapped to -1 before being stored in the client.
	if maxSubs == 0 {
		maxSubs = -1
	}
	now := time.Now().UTC()

	c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws}

	c.registerWithAccount(s.globalAccount())

	var info Info
	var authRequired bool

	s.mu.Lock()
	info = s.copyInfo()
	// Check auth, override if applicable.
	if !info.AuthRequired {
		// Set info.AuthRequired since this is what is sent to the client.
		info.AuthRequired = s.websocket.authOverride
	}
	if s.nonceRequired() {
		var raw [nonceLen]byte
		nonce := raw[:]
		s.generateNonce(nonce)
		info.Nonce = string(nonce)
	}
	c.nonce = []byte(info.Nonce)
	authRequired = info.AuthRequired

	s.totalClients++
	s.mu.Unlock()

	c.mu.Lock()
	if authRequired {
		c.flags.set(expectConnect)
	}
	c.initClient()
	c.Debugf("Client connection created")
	c.sendProtoNow(c.generateClientInfoJSON(info))
	c.mu.Unlock()

	s.mu.Lock()
	if !s.isRunning() || s.ldm {
		if s.isShuttingDown() {
			conn.Close()
		}
		s.mu.Unlock()
		return c
	}

	if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn {
		s.mu.Unlock()
		c.maxConnExceeded()
		return nil
	}
	s.clients[c.cid] = c
	s.mu.Unlock()

	c.mu.Lock()
	// Websocket clients do TLS in the websocket http server.
	// So no TLS initiation here...
	if _, ok := conn.(*tls.Conn); ok {
		c.flags.set(handshakeComplete)
	}

	if c.isClosed() {
		c.mu.Unlock()
		c.closeConnection(WriteError)
		return nil
	}

	if authRequired {
		timeout := opts.AuthTimeout
		// Possibly override with Websocket specific value.
		if opts.Websocket.AuthTimeout != 0 {
			timeout = opts.Websocket.AuthTimeout
		}
		c.setAuthTimer(secondsToDuration(timeout))
	}

	c.setPingTimer()

	s.startGoRoutine(func() { c.readLoop(nil) })
	s.startGoRoutine(func() { c.writeLoop() })

	c.mu.Unlock()

	return c
}

// wsCollapsePtoNB turns the pending outbound buffers (c.out.nb) into
// websocket frames and returns them together with c.ws.fs (the running
// total of frame bytes).
//
// The result starts with any already-framed buffers in c.ws.frames.
// Pending data is then handled in one of three ways:
//   - compressed with flate.BestSpeed, when c.ws.compress is set and the
//     total exceeds wsCompressThreshold;
//   - split into frames of at most wsFrameSizeForBrowsers bytes, when
//     c.ws.browser is set;
//   - otherwise sent as one single frame.
//
// A pending close message is appended last. c.ws.frames is then cleared.
//
// c.out.pb is adjusted to account for frame headers and, when compressing,
// for the size difference between compressed and original data.
// Returns (nil, 0) and marks the connection closed if the compressor
// fails to flush.
// NOTE(review): mutates c.out and c.ws without locking here; presumably
// the caller holds c.mu — confirm.
func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
	nb := c.out.nb
	var mfs int // max frame size; 0 means unlimited
	var usz int // total uncompressed size (only computed when compressing)
	if c.ws.browser {
		mfs = wsFrameSizeForBrowsers
	}
	mask := c.ws.maskwrite
	// Start with possible already framed buffers (that we could have
	// got from partials or control messages such as ws pings or pongs).
	bufs := c.ws.frames
	compress := c.ws.compress
	if compress && len(nb) > 0 {
		// First, make sure we don't compress for very small cumulative buffers.
		for _, b := range nb {
			usz += len(b)
		}
		if usz <= wsCompressThreshold {
			compress = false
		}
	}
	if compress && len(nb) > 0 {
		// Overwrite mfs if this connection does not support fragmented compressed frames.
		if mfs > 0 && c.ws.nocompfrag {
			mfs = 0
		}
		buf := bytes.NewBuffer(nbPoolGet(usz))
		cp := c.ws.compressor
		if cp == nil {
			c.ws.compressor, _ = flate.NewWriter(buf, flate.BestSpeed)
			cp = c.ws.compressor
		} else {
			cp.Reset(buf)
		}
		var csz int // compressed bytes emitted, frame headers included
		for _, b := range nb {
			cp.Write(b)
			nbPoolPut(b) // No longer needed as contents written to compressor.
		}
		if err := cp.Flush(); err != nil {
			c.Errorf("Error during compression: %v", err)
			c.markConnAsClosed(WriteError)
			return nil, 0
		}
		b := buf.Bytes()
		// Drop the last 4 bytes of the flushed output; per RFC 7692 §7.2.1
		// these should be the 0x00 0x00 0xff 0xff tail of the sync flush,
		// which is not sent on the wire.
		p := b[:len(b)-4]
		if mfs > 0 && len(p) > mfs {
			for first, final := true, false; len(p) > 0; first = false {
				lp := len(p)
				if lp > mfs {
					lp = mfs
				} else {
					final = true
				}
				// Only the first frame should be marked as compressed, so pass
				// `first` for the compressed boolean.
				fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]
				n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp)
				if mask {
					wsMaskBuf(key, p[:lp])
				}
				bufs = append(bufs, fh[:n], p[:lp])
				csz += n + lp
				p = p[lp:]
			}
		} else {
			ol := len(p)
			h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, ol)
			if mask {
				wsMaskBuf(key, p)
			}
			if ol > 0 {
				bufs = append(bufs, h, p)
			}
			csz = len(h) + ol
		}
		// Make sure that the compressor no longer holds a reference to
		// the bytes.Buffer, so that the underlying memory gets cleaned
		// up after flushOutbound/flushAndClose. For this to be safe, we
		// always cp.Reset(...) before reusing the compressor again.
		cp.Reset(nil)
		// Add to pb the compressed data size (including headers), but
		// remove the original uncompressed data size that was added
		// during the queueing.
		c.out.pb += int64(csz) - int64(usz)
		c.ws.fs += int64(csz)
	} else if len(nb) > 0 {
		var total int
		if mfs > 0 {
			// We are limiting the frame size.
			// startFrame reserves a header slot in bufs and returns its index.
			startFrame := func() int {
				bufs = append(bufs, nbPoolGet(wsMaxFrameHeaderSize))
				return len(bufs) - 1
			}
			// endFrame fills the header at idx for a payload of `size` bytes,
			// updates the pb/fs accounting and masks the payload buffers that
			// follow the header when masking is required.
			endFrame := func(idx, size int) {
				bufs[idx] = bufs[idx][:wsMaxFrameHeaderSize]
				n, key := wsFillFrameHeader(bufs[idx], mask, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, wsBinaryMessage, size)
				bufs[idx] = bufs[idx][:n]
				c.out.pb += int64(n)
				c.ws.fs += int64(n + size)
				if mask {
					wsMaskBufs(key, bufs[idx+1:])
				}
			}

			fhIdx := startFrame()
			for i := 0; i < len(nb); i++ {
				b := nb[i]
				// Whole buffer fits in the current frame: copy it in.
				if total+len(b) <= mfs {
					buf := nbPoolGet(len(b))
					bufs = append(bufs, append(buf, b...))
					total += len(b)
					nbPoolPut(nb[i])
					continue
				}
				// Otherwise close the current (non-empty) frame and split b
				// across as many frames of at most mfs bytes as needed.
				for len(b) > 0 {
					endStart := total != 0
					if endStart {
						endFrame(fhIdx, total)
					}
					total = len(b)
					if total >= mfs {
						total = mfs
					}
					if endStart {
						fhIdx = startFrame()
					}
					buf := nbPoolGet(total)
					bufs = append(bufs, append(buf, b[:total]...))
					b = b[total:]
				}
				nbPoolPut(nb[i]) // No longer needed as copied into smaller frames.
			}
			if total > 0 {
				endFrame(fhIdx, total)
			}
		} else {
			// If there is no limit on the frame size, create a single frame for
			// all pending buffers.
			for _, b := range nb {
				total += len(b)
			}
			wsfh, key := wsCreateFrameHeader(mask, false, wsBinaryMessage, total)
			c.out.pb += int64(len(wsfh))
			bufs = append(bufs, wsfh)
			idx := len(bufs)
			bufs = append(bufs, nb...)
			if mask {
				wsMaskBufs(key, bufs[idx:])
			}
			c.ws.fs += int64(len(wsfh) + total)
		}
	}
	if len(c.ws.closeMsg) > 0 {
		bufs = append(bufs, c.ws.closeMsg)
		c.ws.fs += int64(len(c.ws.closeMsg))
		c.ws.closeMsg = nil
	}
	c.ws.frames = nil
	return bufs, c.ws.fs
}

// isWSURL reports whether u's scheme starts with "ws" (case-insensitive).
// Because it is a prefix test, "wss" URLs also return true.
func isWSURL(u *url.URL) bool {
	return strings.HasPrefix(strings.ToLower(u.Scheme), wsSchemePrefix)
}

// isWSSURL reports whether u's scheme starts with "wss" (case-insensitive).
func isWSSURL(u *url.URL) bool {
	return strings.HasPrefix(strings.ToLower(u.Scheme), wsSchemePrefixTLS)
}