get.pme.sh/pnats@v0.0.0-20240304004023-26bb5a137ed0/server/websocket.go (about)

     1  // Copyright 2020-2024 The NATS Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package server
    15  
    16  import (
    17  	"bytes"
    18  	crand "crypto/rand"
    19  	"crypto/sha1"
    20  	"crypto/tls"
    21  	"encoding/base64"
    22  	"encoding/binary"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"log"
    27  	mrand "math/rand"
    28  	"net"
    29  	"net/http"
    30  	"net/url"
    31  	"strconv"
    32  	"strings"
    33  	"sync"
    34  	"time"
    35  	"unicode/utf8"
    36  
    37  	"github.com/klauspost/compress/flate"
    38  )
    39  
// wsOpCode is a websocket frame opcode: the low 4 bits of the first byte of
// a frame header (see wsRead).
type wsOpCode int
    41  
const (
	// Frame opcodes.
	// From https://tools.ietf.org/html/rfc6455#section-5.2
	wsTextMessage   = wsOpCode(1)
	wsBinaryMessage = wsOpCode(2)
	wsCloseMessage  = wsOpCode(8)
	wsPingMessage   = wsOpCode(9)
	wsPongMessage   = wsOpCode(10)

	// Bits of the first byte of a frame header.
	wsFinalBit = 1 << 7
	wsRsv1Bit  = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6
	wsRsv2Bit  = 1 << 5
	wsRsv3Bit  = 1 << 4

	// Bit of the second byte of a frame header, set when the payload is masked.
	wsMaskBit = 1 << 7

	// wsMaxFrameHeaderSize is 2 bytes + up to 8 bytes of extended payload
	// length + 4 bytes of masking key (see wsFillFrameHeader). The masking
	// key is accounted for since LeafNode may need to behave as a client.
	// wsCloseSatusSize (sic) is the size of the status code that starts the
	// payload of a close frame.
	wsContinuationFrame     = 0
	wsMaxFrameHeaderSize    = 14 // Since LeafNode may need to behave as a client
	wsMaxControlPayloadSize = 125
	wsFrameSizeForBrowsers  = 4096 // From experiment, web browsers behave better with limited frame size
	wsCompressThreshold     = 64   // Don't compress for small buffer(s)
	wsCloseSatusSize        = 2

	// Close status codes.
	// From https://tools.ietf.org/html/rfc6455#section-11.7
	wsCloseStatusNormalClosure      = 1000
	wsCloseStatusGoingAway          = 1001
	wsCloseStatusProtocolError      = 1002
	wsCloseStatusUnsupportedData    = 1003
	wsCloseStatusNoStatusReceived   = 1005
	wsCloseStatusAbnormalClosure    = 1006
	wsCloseStatusInvalidPayloadData = 1007
	wsCloseStatusPolicyViolation    = 1008
	wsCloseStatusMessageTooBig      = 1009
	wsCloseStatusInternalSrvError   = 1011
	wsCloseStatusTLSHandshake       = 1015

	// Readable names for the boolean flags passed to wsFillFrameHeader.
	wsFirstFrame        = true
	wsContFrame         = false
	wsFinalFrame        = true
	wsUncompressedFrame = false

	wsSchemePrefix    = "ws"
	wsSchemePrefixTLS = "wss"

	// Handshake header names/values, and pre-built response header lines
	// (each ending with CR LF) appended by wsUpgrade.
	wsNoMaskingHeader       = "Nats-No-Masking"
	wsNoMaskingValue        = "true"
	wsXForwardedForHeader   = "X-Forwarded-For"
	wsNoMaskingFullResponse = wsNoMaskingHeader + ": " + wsNoMaskingValue + CR_LF
	wsPMCExtension          = "permessage-deflate" // per-message compression
	wsPMCSrvNoCtx           = "server_no_context_takeover"
	wsPMCCliNoCtx           = "client_no_context_takeover"
	wsPMCReqHeaderValue     = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx
	wsPMCFullResponse       = "Sec-WebSocket-Extensions: " + wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx + _CRLF_
	wsSecProto              = "Sec-Websocket-Protocol"
	wsMQTTSecProtoVal       = "mqtt"
	wsMQTTSecProto          = wsSecProto + ": " + wsMQTTSecProtoVal + CR_LF
)
    98  
// decompressorPool holds flate decompressors (stored as io.ReadCloser) that
// wsReadInfo.decompress() reuses across messages.
var decompressorPool sync.Pool

// compressLastBlock is appended to the compressed fragments of a message
// before inflating it: the 0x00 0x00 0xff 0xff tail described in
// https://tools.ietf.org/html/rfc7692#section-7.2.2, followed by a final
// block so that the flate reader does not report an unexpected EOF.
var compressLastBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}

// From https://tools.ietf.org/html/rfc6455#section-1.3
// GUID used when computing the Sec-WebSocket-Accept handshake value
// (presumably by wsAcceptKey, which is not shown here — confirm).
var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")

// Test can enable this so that server does not support "no-masking" requests.
// wsUpgrade then ignores the Nats-No-Masking request header.
var wsTestRejectNoMasking = false
   107  
// websocket holds the websocket-specific state attached to a client connection.
type websocket struct {
	// Frames queued for sending and their total size in bytes. Control
	// frames are appended by wsEnqueueControlMessageLocked (other producers
	// may exist elsewhere in the file).
	frames         net.Buffers
	fs             int64
	// Close frame held back until there are no more pending buffers to send.
	closeMsg       []byte
	// Whether per-message deflate was negotiated during the handshake.
	compress       bool
	// Set once a close frame has been enqueued (only one should be sent).
	closeSent      bool
	// User-Agent starts with "Mozilla/", i.e. likely a browser.
	browser        bool
	nocompfrag     bool // No fragment for compressed frames
	// Whether frames read from / written to the peer are masked.
	maskread       bool
	maskwrite      bool
	// NOTE(review): not used in the visible code; presumably the flate writer
	// used to compress outbound messages — confirm.
	compressor     *flate.Writer
	// Credentials extracted from the cookies named in the websocket options.
	cookieJwt      string
	cookieUsername string
	cookiePassword string
	cookieToken    string
	// Client IP taken from the X-Forwarded-For header, when valid.
	clientIP       string
}
   125  
// srvWebsocket holds the server-side state of the websocket listener.
type srvWebsocket struct {
	// mu is read-locked by checkOrigin to read sameOrigin and allowedOrigins;
	// which other fields it protects is not visible here — confirm.
	mu             sync.RWMutex
	server         *http.Server
	listener       net.Listener
	listenerErr    error
	tls            bool
	allowedOrigins map[string]*allowedOrigin // host will be the key
	sameOrigin     bool
	connectURLs    []string
	connectURLsMap refCountedUrlSet
	authOverride   bool // indicate if there is auth override in websocket config
}
   138  
// allowedOrigin is the scheme and port of an allowed origin; the host is the
// key of srvWebsocket.allowedOrigins.
type allowedOrigin struct {
	scheme string
	port   string
}
   143  
// wsUpgradeResult is returned by a successful wsUpgrade: the hijacked
// connection, its websocket state, and the connection kind (CLIENT, LEAF or
// MQTT) derived from the request path.
type wsUpgradeResult struct {
	conn net.Conn
	ws   *websocket
	kind int
}
   149  
// wsReadInfo keeps the frame-decoding state of a connection across calls to
// wsRead, and exposes the pending compressed fragments through Read/ReadByte
// for the flate decompressor.
type wsReadInfo struct {
	// Bytes of the current frame's payload that are still to be read.
	rem   int
	// The next byte to decode is the start of a frame header.
	fs    bool
	// The final frame of the current message has been seen.
	ff    bool
	// The current message is compressed.
	fc    bool
	mask  bool // Incoming leafnode connections may not have masking.
	// Index in mkey of the key byte to apply to the next payload byte.
	mkpos byte
	// Masking key of the current frame.
	mkey  [4]byte
	// Compressed fragments waiting for decompression, and the read offset
	// into cbufs[0].
	cbufs [][]byte
	coff  int
}
   161  
   162  func (r *wsReadInfo) init() {
   163  	r.fs, r.ff = true, true
   164  }
   165  
   166  // Returns a slice containing `needed` bytes from the given buffer `buf`
   167  // starting at position `pos`, and possibly read from the given reader `r`.
   168  // When bytes are present in `buf`, the `pos` is incremented by the number
   169  // of bytes found up to `needed` and the new position is returned. If not
   170  // enough bytes are found, the bytes found in `buf` are copied to the returned
   171  // slice and the remaning bytes are read from `r`.
   172  func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) {
   173  	avail := len(buf) - pos
   174  	if avail >= needed {
   175  		return buf[pos : pos+needed], pos + needed, nil
   176  	}
   177  	b := make([]byte, needed)
   178  	start := copy(b, buf[pos:])
   179  	for start != needed {
   180  		n, err := r.Read(b[start:cap(b)])
   181  		if err != nil {
   182  			return nil, 0, err
   183  		}
   184  		start += n
   185  	}
   186  	return b, pos + avail, nil
   187  }
   188  
// Returns true if this connection is from a Websocket client, i.e. if
// websocket state (c.ws) has been attached to it.
// Lock held on entry.
func (c *client) isWebsocket() bool {
	return c.ws != nil
}
   194  
   195  // Returns a slice of byte slices corresponding to payload of websocket frames.
   196  // The byte slice `buf` is filled with bytes from the connection's read loop.
   197  // This function will decode the frame headers and unmask the payload(s).
   198  // It is possible that the returned slices point to the given `buf` slice, so
   199  // `buf` should not be overwritten until the returned slices have been parsed.
   200  //
   201  // Client lock MUST NOT be held on entry.
   202  func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, error) {
   203  	var (
   204  		bufs   [][]byte
   205  		tmpBuf []byte
   206  		err    error
   207  		pos    int
   208  		max    = len(buf)
   209  	)
   210  	for pos != max {
   211  		if r.fs {
   212  			b0 := buf[pos]
   213  			frameType := wsOpCode(b0 & 0xF)
   214  			final := b0&wsFinalBit != 0
   215  			compressed := b0&wsRsv1Bit != 0
   216  			pos++
   217  
   218  			tmpBuf, pos, err = wsGet(ior, buf, pos, 1)
   219  			if err != nil {
   220  				return bufs, err
   221  			}
   222  			b1 := tmpBuf[0]
   223  
   224  			// Clients MUST set the mask bit. If not set, reject.
   225  			// However, LEAF by default will not have masking, unless they are forced to, by configuration.
   226  			if r.mask && b1&wsMaskBit == 0 {
   227  				return bufs, c.wsHandleProtocolError("mask bit missing")
   228  			}
   229  
   230  			// Store size in case it is < 125
   231  			r.rem = int(b1 & 0x7F)
   232  
   233  			switch frameType {
   234  			case wsPingMessage, wsPongMessage, wsCloseMessage:
   235  				if r.rem > wsMaxControlPayloadSize {
   236  					return bufs, c.wsHandleProtocolError(
   237  						fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes",
   238  							wsMaxControlPayloadSize))
   239  				}
   240  				if !final {
   241  					return bufs, c.wsHandleProtocolError("control frame does not have final bit set")
   242  				}
   243  			case wsTextMessage, wsBinaryMessage:
   244  				if !r.ff {
   245  					return bufs, c.wsHandleProtocolError("new message started before final frame for previous message was received")
   246  				}
   247  				r.ff = final
   248  				r.fc = compressed
   249  			case wsContinuationFrame:
   250  				// Compressed bit must be only set in the first frame
   251  				if r.ff || compressed {
   252  					return bufs, c.wsHandleProtocolError("invalid continuation frame")
   253  				}
   254  				r.ff = final
   255  			default:
   256  				return bufs, c.wsHandleProtocolError(fmt.Sprintf("unknown opcode %v", frameType))
   257  			}
   258  
   259  			switch r.rem {
   260  			case 126:
   261  				tmpBuf, pos, err = wsGet(ior, buf, pos, 2)
   262  				if err != nil {
   263  					return bufs, err
   264  				}
   265  				r.rem = int(binary.BigEndian.Uint16(tmpBuf))
   266  			case 127:
   267  				tmpBuf, pos, err = wsGet(ior, buf, pos, 8)
   268  				if err != nil {
   269  					return bufs, err
   270  				}
   271  				r.rem = int(binary.BigEndian.Uint64(tmpBuf))
   272  			}
   273  
   274  			if r.mask {
   275  				// Read masking key
   276  				tmpBuf, pos, err = wsGet(ior, buf, pos, 4)
   277  				if err != nil {
   278  					return bufs, err
   279  				}
   280  				copy(r.mkey[:], tmpBuf)
   281  				r.mkpos = 0
   282  			}
   283  
   284  			// Handle control messages in place...
   285  			if wsIsControlFrame(frameType) {
   286  				pos, err = c.wsHandleControlFrame(r, frameType, ior, buf, pos)
   287  				if err != nil {
   288  					return bufs, err
   289  				}
   290  				continue
   291  			}
   292  
   293  			// Done with the frame header
   294  			r.fs = false
   295  		}
   296  		if pos < max {
   297  			var b []byte
   298  			var n int
   299  
   300  			n = r.rem
   301  			if pos+n > max {
   302  				n = max - pos
   303  			}
   304  			b = buf[pos : pos+n]
   305  			pos += n
   306  			r.rem -= n
   307  			// If needed, unmask the buffer
   308  			if r.mask {
   309  				r.unmask(b)
   310  			}
   311  			addToBufs := true
   312  			// Handle compressed message
   313  			if r.fc {
   314  				// Assume that we may have continuation frames or not the full payload.
   315  				addToBufs = false
   316  				// Make a copy of the buffer before adding it to the list
   317  				// of compressed fragments.
   318  				r.cbufs = append(r.cbufs, append([]byte(nil), b...))
   319  				// When we have the final frame and we have read the full payload,
   320  				// we can decompress it.
   321  				if r.ff && r.rem == 0 {
   322  					b, err = r.decompress()
   323  					if err != nil {
   324  						return bufs, err
   325  					}
   326  					r.fc = false
   327  					// Now we can add to `bufs`
   328  					addToBufs = true
   329  				}
   330  			}
   331  			// For non compressed frames, or when we have decompressed the
   332  			// whole message.
   333  			if addToBufs {
   334  				bufs = append(bufs, b)
   335  			}
   336  			// If payload has been fully read, then indicate that next
   337  			// is the start of a frame.
   338  			if r.rem == 0 {
   339  				r.fs = true
   340  			}
   341  		}
   342  	}
   343  	return bufs, nil
   344  }
   345  
   346  func (r *wsReadInfo) Read(dst []byte) (int, error) {
   347  	if len(dst) == 0 {
   348  		return 0, nil
   349  	}
   350  	if len(r.cbufs) == 0 {
   351  		return 0, io.EOF
   352  	}
   353  	copied := 0
   354  	rem := len(dst)
   355  	for buf := r.cbufs[0]; buf != nil && rem > 0; {
   356  		n := len(buf[r.coff:])
   357  		if n > rem {
   358  			n = rem
   359  		}
   360  		copy(dst[copied:], buf[r.coff:r.coff+n])
   361  		copied += n
   362  		rem -= n
   363  		r.coff += n
   364  		buf = r.nextCBuf()
   365  	}
   366  	return copied, nil
   367  }
   368  
   369  func (r *wsReadInfo) nextCBuf() []byte {
   370  	// We still have remaining data in the first buffer
   371  	if r.coff != len(r.cbufs[0]) {
   372  		return r.cbufs[0]
   373  	}
   374  	// We read the full first buffer. Reset offset.
   375  	r.coff = 0
   376  	// We were at the last buffer, so we are done.
   377  	if len(r.cbufs) == 1 {
   378  		r.cbufs = nil
   379  		return nil
   380  	}
   381  	// Here we move to the next buffer.
   382  	r.cbufs = r.cbufs[1:]
   383  	return r.cbufs[0]
   384  }
   385  
   386  func (r *wsReadInfo) ReadByte() (byte, error) {
   387  	if len(r.cbufs) == 0 {
   388  		return 0, io.EOF
   389  	}
   390  	b := r.cbufs[0][r.coff]
   391  	r.coff++
   392  	r.nextCBuf()
   393  	return b, nil
   394  }
   395  
   396  func (r *wsReadInfo) decompress() ([]byte, error) {
   397  	r.coff = 0
   398  	// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
   399  	// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
   400  	// does not report unexpected EOF.
   401  	r.cbufs = append(r.cbufs, compressLastBlock)
   402  	// Get a decompressor from the pool and bind it to this object (wsReadInfo)
   403  	// that provides Read() and ReadByte() APIs that will consume the compressed
   404  	// buffers (r.cbufs).
   405  	d, _ := decompressorPool.Get().(io.ReadCloser)
   406  	if d == nil {
   407  		d = flate.NewReader(r)
   408  	} else {
   409  		d.(flate.Resetter).Reset(r, nil)
   410  	}
   411  	// This will do the decompression.
   412  	b, err := io.ReadAll(d)
   413  	decompressorPool.Put(d)
   414  	// Now reset the compressed buffers list.
   415  	r.cbufs = nil
   416  	return b, err
   417  }
   418  
   419  // Handles the PING, PONG and CLOSE websocket control frames.
   420  //
   421  // Client lock MUST NOT be held on entry.
   422  func (c *client) wsHandleControlFrame(r *wsReadInfo, frameType wsOpCode, nc io.Reader, buf []byte, pos int) (int, error) {
   423  	var payload []byte
   424  	var err error
   425  
   426  	if r.rem > 0 {
   427  		payload, pos, err = wsGet(nc, buf, pos, r.rem)
   428  		if err != nil {
   429  			return pos, err
   430  		}
   431  		if r.mask {
   432  			r.unmask(payload)
   433  		}
   434  		r.rem = 0
   435  	}
   436  	switch frameType {
   437  	case wsCloseMessage:
   438  		status := wsCloseStatusNoStatusReceived
   439  		var body string
   440  		lp := len(payload)
   441  		// If there is a payload, the status is represented as a 2-byte
   442  		// unsigned integer (in network byte order). Then, there may be an
   443  		// optional body.
   444  		hasStatus, hasBody := lp >= wsCloseSatusSize, lp > wsCloseSatusSize
   445  		if hasStatus {
   446  			// Decode the status
   447  			status = int(binary.BigEndian.Uint16(payload[:wsCloseSatusSize]))
   448  			// Now if there is a body, capture it and make sure this is a valid UTF-8.
   449  			if hasBody {
   450  				body = string(payload[wsCloseSatusSize:])
   451  				if !utf8.ValidString(body) {
   452  					// https://tools.ietf.org/html/rfc6455#section-5.5.1
   453  					// If body is present, it must be a valid utf8
   454  					status = wsCloseStatusInvalidPayloadData
   455  					body = "invalid utf8 body in close frame"
   456  				}
   457  			}
   458  		}
   459  		clm := wsCreateCloseMessage(status, body)
   460  		c.wsEnqueueControlMessage(wsCloseMessage, clm)
   461  		nbPoolPut(clm) // wsEnqueueControlMessage has taken a copy.
   462  		// Return io.EOF so that readLoop will close the connection as ClientClosed
   463  		// after processing pending buffers.
   464  		return pos, io.EOF
   465  	case wsPingMessage:
   466  		c.wsEnqueueControlMessage(wsPongMessage, payload)
   467  	case wsPongMessage:
   468  		// Nothing to do..
   469  	}
   470  	return pos, nil
   471  }
   472  
   473  // Unmask the given slice.
   474  func (r *wsReadInfo) unmask(buf []byte) {
   475  	p := int(r.mkpos)
   476  	if len(buf) < 16 {
   477  		for i := 0; i < len(buf); i++ {
   478  			buf[i] ^= r.mkey[p&3]
   479  			p++
   480  		}
   481  		r.mkpos = byte(p & 3)
   482  		return
   483  	}
   484  	var k [8]byte
   485  	for i := 0; i < 8; i++ {
   486  		k[i] = r.mkey[(p+i)&3]
   487  	}
   488  	km := binary.BigEndian.Uint64(k[:])
   489  	n := (len(buf) / 8) * 8
   490  	for i := 0; i < n; i += 8 {
   491  		tmp := binary.BigEndian.Uint64(buf[i : i+8])
   492  		tmp ^= km
   493  		binary.BigEndian.PutUint64(buf[i:], tmp)
   494  	}
   495  	buf = buf[n:]
   496  	for i := 0; i < len(buf); i++ {
   497  		buf[i] ^= r.mkey[p&3]
   498  		p++
   499  	}
   500  	r.mkpos = byte(p & 3)
   501  }
   502  
// Returns true if the op code corresponds to a control frame.
// Close, ping and pong use opcodes 8 and above; wsRead rejects unknown
// opcodes before this is called.
func wsIsControlFrame(frameType wsOpCode) bool {
	return frameType >= wsCloseMessage
}
   507  
   508  // Create the frame header.
   509  // Encodes the frame type and optional compression flag, and the size of the payload.
   510  func wsCreateFrameHeader(useMasking, compressed bool, frameType wsOpCode, l int) ([]byte, []byte) {
   511  	fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]
   512  	n, key := wsFillFrameHeader(fh, useMasking, wsFirstFrame, wsFinalFrame, compressed, frameType, l)
   513  	return fh[:n], key
   514  }
   515  
   516  func wsFillFrameHeader(fh []byte, useMasking, first, final, compressed bool, frameType wsOpCode, l int) (int, []byte) {
   517  	var n int
   518  	var b byte
   519  	if first {
   520  		b = byte(frameType)
   521  	}
   522  	if final {
   523  		b |= wsFinalBit
   524  	}
   525  	if compressed {
   526  		b |= wsRsv1Bit
   527  	}
   528  	b1 := byte(0)
   529  	if useMasking {
   530  		b1 |= wsMaskBit
   531  	}
   532  	switch {
   533  	case l <= 125:
   534  		n = 2
   535  		fh[0] = b
   536  		fh[1] = b1 | byte(l)
   537  	case l < 65536:
   538  		n = 4
   539  		fh[0] = b
   540  		fh[1] = b1 | 126
   541  		binary.BigEndian.PutUint16(fh[2:], uint16(l))
   542  	default:
   543  		n = 10
   544  		fh[0] = b
   545  		fh[1] = b1 | 127
   546  		binary.BigEndian.PutUint64(fh[2:], uint64(l))
   547  	}
   548  	var key []byte
   549  	if useMasking {
   550  		var keyBuf [4]byte
   551  		if _, err := io.ReadFull(crand.Reader, keyBuf[:4]); err != nil {
   552  			kv := mrand.Int31()
   553  			binary.LittleEndian.PutUint32(keyBuf[:4], uint32(kv))
   554  		}
   555  		copy(fh[n:], keyBuf[:4])
   556  		key = fh[n : n+4]
   557  		n += 4
   558  	}
   559  	return n, key
   560  }
   561  
// Invokes wsEnqueueControlMessageLocked under client lock.
// The payload is copied into the new frame, so the caller keeps ownership of
// `payload` (callers in this file return pooled payloads right after).
//
// Client lock MUST NOT be held on entry
func (c *client) wsEnqueueControlMessage(controlMsg wsOpCode, payload []byte) {
	c.mu.Lock()
	c.wsEnqueueControlMessageLocked(controlMsg, payload)
	c.mu.Unlock()
}
   570  
   571  // Mask the buffer with the given key
   572  func wsMaskBuf(key, buf []byte) {
   573  	for i := 0; i < len(buf); i++ {
   574  		buf[i] ^= key[i&3]
   575  	}
   576  }
   577  
   578  // Mask the buffers, as if they were contiguous, with the given key
   579  func wsMaskBufs(key []byte, bufs [][]byte) {
   580  	pos := 0
   581  	for i := 0; i < len(bufs); i++ {
   582  		buf := bufs[i]
   583  		for j := 0; j < len(buf); j++ {
   584  			buf[j] ^= key[pos&3]
   585  			pos++
   586  		}
   587  	}
   588  }
   589  
   590  // Enqueues a websocket control message.
   591  // If the control message is a wsCloseMessage, then marks this client
   592  // has having sent the close message (since only one should be sent).
   593  // This will prevent the generic closeConnection() to enqueue one.
   594  //
   595  // Client lock held on entry.
   596  func (c *client) wsEnqueueControlMessageLocked(controlMsg wsOpCode, payload []byte) {
   597  	// Control messages are never compressed and their size will be
   598  	// less than wsMaxControlPayloadSize, which means the frame header
   599  	// will be only 2 or 6 bytes.
   600  	useMasking := c.ws.maskwrite
   601  	sz := 2
   602  	if useMasking {
   603  		sz += 4
   604  	}
   605  	cm := nbPoolGet(sz + len(payload))
   606  	cm = cm[:cap(cm)]
   607  	n, key := wsFillFrameHeader(cm, useMasking, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, controlMsg, len(payload))
   608  	cm = cm[:n]
   609  	// Note that payload is optional.
   610  	if len(payload) > 0 {
   611  		cm = append(cm, payload...)
   612  		if useMasking {
   613  			wsMaskBuf(key, cm[n:])
   614  		}
   615  	}
   616  	c.out.pb += int64(len(cm))
   617  	if controlMsg == wsCloseMessage {
   618  		// We can't add the close message to the frames buffers
   619  		// now. It will be done on a flushOutbound() when there
   620  		// are no more pending buffers to send.
   621  		c.ws.closeSent = true
   622  		c.ws.closeMsg = cm
   623  	} else {
   624  		c.ws.frames = append(c.ws.frames, cm)
   625  		c.ws.fs += int64(len(cm))
   626  	}
   627  	c.flushSignal()
   628  }
   629  
   630  // Enqueues a websocket close message with a status mapped from the given `reason`.
   631  //
   632  // Client lock held on entry
   633  func (c *client) wsEnqueueCloseMessage(reason ClosedState) {
   634  	var status int
   635  	switch reason {
   636  	case ClientClosed:
   637  		status = wsCloseStatusNormalClosure
   638  	case AuthenticationTimeout, AuthenticationViolation, SlowConsumerPendingBytes, SlowConsumerWriteDeadline,
   639  		MaxAccountConnectionsExceeded, MaxConnectionsExceeded, MaxControlLineExceeded, MaxSubscriptionsExceeded,
   640  		MissingAccount, AuthenticationExpired, Revocation:
   641  		status = wsCloseStatusPolicyViolation
   642  	case TLSHandshakeError:
   643  		status = wsCloseStatusTLSHandshake
   644  	case ParseError, ProtocolViolation, BadClientProtocolVersion:
   645  		status = wsCloseStatusProtocolError
   646  	case MaxPayloadExceeded:
   647  		status = wsCloseStatusMessageTooBig
   648  	case ServerShutdown:
   649  		status = wsCloseStatusGoingAway
   650  	case WriteError, ReadError, StaleConnection:
   651  		status = wsCloseStatusAbnormalClosure
   652  	default:
   653  		status = wsCloseStatusInternalSrvError
   654  	}
   655  	body := wsCreateCloseMessage(status, reason.String())
   656  	c.wsEnqueueControlMessageLocked(wsCloseMessage, body)
   657  	nbPoolPut(body) // wsEnqueueControlMessageLocked has taken a copy.
   658  }
   659  
   660  // Create and then enqueue a close message with a protocol error and the
   661  // given message. This is invoked when parsing websocket frames.
   662  //
   663  // Lock MUST NOT be held on entry.
   664  func (c *client) wsHandleProtocolError(message string) error {
   665  	buf := wsCreateCloseMessage(wsCloseStatusProtocolError, message)
   666  	c.wsEnqueueControlMessage(wsCloseMessage, buf)
   667  	nbPoolPut(buf) // wsEnqueueControlMessage has taken a copy.
   668  	return fmt.Errorf(message)
   669  }
   670  
   671  // Create a close message with the given `status` and `body`.
   672  // If the `body` is more than the maximum allows control frame payload size,
   673  // it is truncated and "..." is added at the end (as a hint that message
   674  // is not complete).
   675  func wsCreateCloseMessage(status int, body string) []byte {
   676  	// Since a control message payload is limited in size, we
   677  	// will limit the text and add trailing "..." if truncated.
   678  	// The body of a Close Message must be preceded with 2 bytes,
   679  	// so take that into account for limiting the body length.
   680  	if len(body) > wsMaxControlPayloadSize-2 {
   681  		body = body[:wsMaxControlPayloadSize-5]
   682  		body += "..."
   683  	}
   684  	buf := nbPoolGet(2 + len(body))[:2+len(body)]
   685  	// We need to have a 2 byte unsigned int that represents the error status code
   686  	// https://tools.ietf.org/html/rfc6455#section-5.5.1
   687  	binary.BigEndian.PutUint16(buf[:2], uint16(status))
   688  	copy(buf[2:], []byte(body))
   689  	return buf
   690  }
   691  
   692  // Process websocket client handshake. On success, returns the raw net.Conn that
   693  // will be used to create a *client object.
   694  // Invoked from the HTTP server listening on websocket port.
   695  func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeResult, error) {
   696  	kind := CLIENT
   697  	if r.URL != nil {
   698  		ep := r.URL.EscapedPath()
   699  		if strings.HasSuffix(ep, leafNodeWSPath) {
   700  			kind = LEAF
   701  		} else if strings.HasSuffix(ep, mqttWSPath) {
   702  			kind = MQTT
   703  		}
   704  	}
   705  
   706  	opts := s.getOpts()
   707  
   708  	// From https://tools.ietf.org/html/rfc6455#section-4.2.1
   709  	// Point 1.
   710  	if r.Method != "GET" {
   711  		return nil, wsReturnHTTPError(w, r, http.StatusMethodNotAllowed, "request method must be GET")
   712  	}
   713  	// Point 2.
   714  	if r.Host == _EMPTY_ {
   715  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "'Host' missing in request")
   716  	}
   717  	// Point 3.
   718  	if !wsHeaderContains(r.Header, "Upgrade", "websocket") {
   719  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Upgrade'")
   720  	}
   721  	// Point 4.
   722  	if !wsHeaderContains(r.Header, "Connection", "Upgrade") {
   723  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Connection'")
   724  	}
   725  	// Point 5.
   726  	key := r.Header.Get("Sec-Websocket-Key")
   727  	if key == _EMPTY_ {
   728  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "key missing")
   729  	}
   730  	// Point 6.
   731  	if !wsHeaderContains(r.Header, "Sec-Websocket-Version", "13") {
   732  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid version")
   733  	}
   734  	// Others are optional
   735  	// Point 7.
   736  	if err := s.websocket.checkOrigin(r); err != nil {
   737  		return nil, wsReturnHTTPError(w, r, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err))
   738  	}
   739  	// Point 8.
   740  	// We don't have protocols, so ignore.
   741  	// Point 9.
   742  	// Extensions, only support for compression at the moment
   743  	compress := opts.Websocket.Compression
   744  	if compress {
   745  		// Simply check if permessage-deflate extension is present.
   746  		compress, _ = wsPMCExtensionSupport(r.Header, true)
   747  	}
   748  	// We will do masking if asked (unless we reject for tests)
   749  	noMasking := r.Header.Get(wsNoMaskingHeader) == wsNoMaskingValue && !wsTestRejectNoMasking
   750  
   751  	h := w.(http.Hijacker)
   752  	conn, brw, err := h.Hijack()
   753  	if err != nil {
   754  		if conn != nil {
   755  			conn.Close()
   756  		}
   757  		return nil, wsReturnHTTPError(w, r, http.StatusInternalServerError, err.Error())
   758  	}
   759  	if brw.Reader.Buffered() > 0 {
   760  		conn.Close()
   761  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "client sent data before handshake is complete")
   762  	}
   763  
   764  	var buf [1024]byte
   765  	p := buf[:0]
   766  
   767  	// From https://tools.ietf.org/html/rfc6455#section-4.2.2
   768  	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
   769  	p = append(p, wsAcceptKey(key)...)
   770  	p = append(p, _CRLF_...)
   771  	if compress {
   772  		p = append(p, wsPMCFullResponse...)
   773  	}
   774  	if noMasking {
   775  		p = append(p, wsNoMaskingFullResponse...)
   776  	}
   777  	if kind == MQTT {
   778  		p = append(p, wsMQTTSecProto...)
   779  	}
   780  	p = append(p, _CRLF_...)
   781  
   782  	if _, err = conn.Write(p); err != nil {
   783  		conn.Close()
   784  		return nil, err
   785  	}
   786  	// If there was a deadline set for the handshake, clear it now.
   787  	if opts.Websocket.HandshakeTimeout > 0 {
   788  		conn.SetDeadline(time.Time{})
   789  	}
   790  	// Server always expect "clients" to send masked payload, unless the option
   791  	// "no-masking" has been enabled.
   792  	ws := &websocket{compress: compress, maskread: !noMasking}
   793  
   794  	// Check for X-Forwarded-For header
   795  	if cips, ok := r.Header[wsXForwardedForHeader]; ok {
   796  		cip := cips[0]
   797  		if net.ParseIP(cip) != nil {
   798  			ws.clientIP = cip
   799  		}
   800  	}
   801  
   802  	if kind == CLIENT || kind == MQTT {
   803  		// Indicate if this is likely coming from a browser.
   804  		if ua := r.Header.Get("User-Agent"); ua != _EMPTY_ && strings.HasPrefix(ua, "Mozilla/") {
   805  			ws.browser = true
   806  			// Disable fragmentation of compressed frames for Safari browsers.
   807  			// Unfortunately, you could be running Chrome on macOS and this
   808  			// string will contain "Safari/" (along "Chrome/"). However, what
   809  			// I have found is that actual Safari browser also have "Version/".
   810  			// So make the combination of the two.
   811  			ws.nocompfrag = ws.compress && strings.Contains(ua, "Version/") && strings.Contains(ua, "Safari/")
   812  		}
   813  
   814  		if cookies := r.Cookies(); len(cookies) > 0 {
   815  			ows := &opts.Websocket
   816  			for _, c := range cookies {
   817  				if ows.JWTCookie == c.Name {
   818  					ws.cookieJwt = c.Value
   819  				} else if ows.UsernameCookie == c.Name {
   820  					ws.cookieUsername = c.Value
   821  				} else if ows.PasswordCookie == c.Name {
   822  					ws.cookiePassword = c.Value
   823  				} else if ows.TokenCookie == c.Name {
   824  					ws.cookieToken = c.Value
   825  				}
   826  			}
   827  		}
   828  	}
   829  	return &wsUpgradeResult{conn: conn, ws: ws, kind: kind}, nil
   830  }
   831  
   832  // Returns true if the header named `name` contains a token with value `value`.
   833  func wsHeaderContains(header http.Header, name string, value string) bool {
   834  	for _, s := range header[name] {
   835  		tokens := strings.Split(s, ",")
   836  		for _, t := range tokens {
   837  			t = strings.Trim(t, " \t")
   838  			if strings.EqualFold(t, value) {
   839  				return true
   840  			}
   841  		}
   842  	}
   843  	return false
   844  }
   845  
   846  func wsPMCExtensionSupport(header http.Header, checkPMCOnly bool) (bool, bool) {
   847  	for _, extensionList := range header["Sec-Websocket-Extensions"] {
   848  		extensions := strings.Split(extensionList, ",")
   849  		for _, extension := range extensions {
   850  			extension = strings.Trim(extension, " \t")
   851  			params := strings.Split(extension, ";")
   852  			for i, p := range params {
   853  				p = strings.Trim(p, " \t")
   854  				if strings.EqualFold(p, wsPMCExtension) {
   855  					if checkPMCOnly {
   856  						return true, false
   857  					}
   858  					var snc bool
   859  					var cnc bool
   860  					for j := i + 1; j < len(params); j++ {
   861  						p = params[j]
   862  						p = strings.Trim(p, " \t")
   863  						if strings.EqualFold(p, wsPMCSrvNoCtx) {
   864  							snc = true
   865  						} else if strings.EqualFold(p, wsPMCCliNoCtx) {
   866  							cnc = true
   867  						}
   868  						if snc && cnc {
   869  							return true, true
   870  						}
   871  					}
   872  					return true, false
   873  				}
   874  			}
   875  		}
   876  	}
   877  	return false, false
   878  }
   879  
   880  // Send an HTTP error with the given `status` to the given http response writer `w`.
   881  // Return an error created based on the `reason` string.
   882  func wsReturnHTTPError(w http.ResponseWriter, r *http.Request, status int, reason string) error {
   883  	err := fmt.Errorf("%s - websocket handshake error: %s", r.RemoteAddr, reason)
   884  	w.Header().Set("Sec-Websocket-Version", "13")
   885  	http.Error(w, http.StatusText(status), status)
   886  	return err
   887  }
   888  
   889  // If the server is configured to accept any origin, then this function returns
   890  // `nil` without checking if the Origin is present and valid. This is also
   891  // the case if the request does not have the Origin header.
   892  // Otherwise, this will check that the Origin matches the same origin or
   893  // any origin in the allowed list.
   894  func (w *srvWebsocket) checkOrigin(r *http.Request) error {
   895  	w.mu.RLock()
   896  	checkSame := w.sameOrigin
   897  	listEmpty := len(w.allowedOrigins) == 0
   898  	w.mu.RUnlock()
   899  	if !checkSame && listEmpty {
   900  		return nil
   901  	}
   902  	origin := r.Header.Get("Origin")
   903  	if origin == _EMPTY_ {
   904  		origin = r.Header.Get("Sec-Websocket-Origin")
   905  	}
   906  	// If the header is not present, we will accept.
   907  	// From https://datatracker.ietf.org/doc/html/rfc6455#section-1.6
   908  	// "Naturally, when the WebSocket Protocol is used by a dedicated client
   909  	// directly (i.e., not from a web page through a web browser), the origin
   910  	// model is not useful, as the client can provide any arbitrary origin string."
   911  	if origin == _EMPTY_ {
   912  		return nil
   913  	}
   914  	u, err := url.ParseRequestURI(origin)
   915  	if err != nil {
   916  		return err
   917  	}
   918  	oh, op, err := wsGetHostAndPort(u.Scheme == "https", u.Host)
   919  	if err != nil {
   920  		return err
   921  	}
   922  	// If checking same origin, compare with the http's request's Host.
   923  	if checkSame {
   924  		rh, rp, err := wsGetHostAndPort(r.TLS != nil, r.Host)
   925  		if err != nil {
   926  			return err
   927  		}
   928  		if oh != rh || op != rp {
   929  			return errors.New("not same origin")
   930  		}
   931  		// I guess it is possible to have cases where one wants to check
   932  		// same origin, but also that the origin is in the allowed list.
   933  		// So continue with the next check.
   934  	}
   935  	if !listEmpty {
   936  		w.mu.RLock()
   937  		ao := w.allowedOrigins[oh]
   938  		w.mu.RUnlock()
   939  		if ao == nil || u.Scheme != ao.scheme || op != ao.port {
   940  			return errors.New("not in the allowed list")
   941  		}
   942  	}
   943  	return nil
   944  }
   945  
   946  func wsGetHostAndPort(tls bool, hostport string) (string, string, error) {
   947  	host, port, err := net.SplitHostPort(hostport)
   948  	if err != nil {
   949  		// If error is missing port, then use defaults based on the scheme
   950  		if ae, ok := err.(*net.AddrError); ok && strings.Contains(ae.Err, "missing port") {
   951  			err = nil
   952  			host = hostport
   953  			if tls {
   954  				port = "443"
   955  			} else {
   956  				port = "80"
   957  			}
   958  		}
   959  	}
   960  	return strings.ToLower(host), port, err
   961  }
   962  
   963  // Concatenate the key sent by the client with the GUID, then computes the SHA1 hash
   964  // and returns it as a based64 encoded string.
   965  func wsAcceptKey(key string) string {
   966  	h := sha1.New()
   967  	h.Write([]byte(key))
   968  	h.Write(wsGUID)
   969  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
   970  }
   971  
   972  func wsMakeChallengeKey() (string, error) {
   973  	p := make([]byte, 16)
   974  	if _, err := io.ReadFull(crand.Reader, p); err != nil {
   975  		return _EMPTY_, err
   976  	}
   977  	return base64.StdEncoding.EncodeToString(p), nil
   978  }
   979  
   980  // Validate the websocket related options.
   981  func validateWebsocketOptions(o *Options) error {
   982  	wo := &o.Websocket
   983  	// If no port is defined, we don't care about other options
   984  	if wo.Port == 0 {
   985  		return nil
   986  	}
   987  	// Enforce TLS... unless NoTLS is set to true.
   988  	if wo.TLSConfig == nil && !wo.NoTLS {
   989  		return errors.New("websocket requires TLS configuration")
   990  	}
   991  	// Make sure that allowed origins, if specified, can be parsed.
   992  	for _, ao := range wo.AllowedOrigins {
   993  		if _, err := url.Parse(ao); err != nil {
   994  			return fmt.Errorf("unable to parse allowed origin: %v", err)
   995  		}
   996  	}
   997  	// If there is a NoAuthUser, we need to have Users defined and
   998  	// the user to be present.
   999  	if wo.NoAuthUser != _EMPTY_ {
  1000  		if err := validateNoAuthUser(o, wo.NoAuthUser); err != nil {
  1001  			return err
  1002  		}
  1003  	}
  1004  	// Token/Username not possible if there are users/nkeys
  1005  	if len(o.Users) > 0 || len(o.Nkeys) > 0 {
  1006  		if wo.Username != _EMPTY_ {
  1007  			return fmt.Errorf("websocket authentication username not compatible with presence of users/nkeys")
  1008  		}
  1009  		if wo.Token != _EMPTY_ {
  1010  			return fmt.Errorf("websocket authentication token not compatible with presence of users/nkeys")
  1011  		}
  1012  	}
  1013  	// Using JWT requires Trusted Keys
  1014  	if wo.JWTCookie != _EMPTY_ {
  1015  		if len(o.TrustedOperators) == 0 && len(o.TrustedKeys) == 0 {
  1016  			return fmt.Errorf("trusted operators or trusted keys configuration is required for JWT authentication via cookie %q", wo.JWTCookie)
  1017  		}
  1018  	}
  1019  	if err := validatePinnedCerts(wo.TLSPinnedCerts); err != nil {
  1020  		return fmt.Errorf("websocket: %v", err)
  1021  	}
  1022  	return nil
  1023  }
  1024  
  1025  // Creates or updates the existing map
  1026  func (s *Server) wsSetOriginOptions(o *WebsocketOpts) {
  1027  	ws := &s.websocket
  1028  	ws.mu.Lock()
  1029  	defer ws.mu.Unlock()
  1030  	// Copy over the option's same origin boolean
  1031  	ws.sameOrigin = o.SameOrigin
  1032  	// Reset the map. Will help for config reload if/when we support it.
  1033  	ws.allowedOrigins = nil
  1034  	if o.AllowedOrigins == nil {
  1035  		return
  1036  	}
  1037  	for _, ao := range o.AllowedOrigins {
  1038  		// We have previously checked (during options validation) that the urls
  1039  		// are parseable, but if we get an error, report and skip.
  1040  		u, err := url.ParseRequestURI(ao)
  1041  		if err != nil {
  1042  			s.Errorf("error parsing allowed origin: %v", err)
  1043  			continue
  1044  		}
  1045  		h, p, _ := wsGetHostAndPort(u.Scheme == "https", u.Host)
  1046  		if ws.allowedOrigins == nil {
  1047  			ws.allowedOrigins = make(map[string]*allowedOrigin, len(o.AllowedOrigins))
  1048  		}
  1049  		ws.allowedOrigins[h] = &allowedOrigin{scheme: u.Scheme, port: p}
  1050  	}
  1051  }
  1052  
  1053  // Given the websocket options, we check if any auth configuration
  1054  // has been provided. If so, possibly create users/nkey users and
  1055  // store them in s.websocket.users/nkeys.
  1056  // Also update a boolean that indicates if auth is required for
  1057  // websocket clients.
  1058  // Server lock is held on entry.
  1059  func (s *Server) wsConfigAuth(opts *WebsocketOpts) {
  1060  	ws := &s.websocket
  1061  	// If any of those is specified, we consider that there is an override.
  1062  	ws.authOverride = opts.Username != _EMPTY_ || opts.Token != _EMPTY_ || opts.NoAuthUser != _EMPTY_
  1063  }
  1064  
  1065  func (s *Server) startWebsocketServer() {
  1066  	if s.isShuttingDown() {
  1067  		return
  1068  	}
  1069  
  1070  	sopts := s.getOpts()
  1071  	o := &sopts.Websocket
  1072  
  1073  	s.wsSetOriginOptions(o)
  1074  
  1075  	var hl net.Listener
  1076  	var proto string
  1077  	var err error
  1078  
  1079  	port := o.Port
  1080  	if port == -1 {
  1081  		port = 0
  1082  	}
  1083  	hp := net.JoinHostPort(o.Host, strconv.Itoa(port))
  1084  
  1085  	// We are enforcing (when validating the options) the use of TLS, but the
  1086  	// code was originally supporting both modes. The reason for TLS only is
  1087  	// that we expect users to send JWTs with bearer tokens and we want to
  1088  	// avoid the possibility of it being "intercepted".
  1089  
  1090  	s.mu.Lock()
  1091  	// Do not check o.NoTLS here. If a TLS configuration is available, use it,
  1092  	// regardless of NoTLS. If we don't have a TLS config, it means that the
  1093  	// user has configured NoTLS because otherwise the server would have failed
  1094  	// to start due to options validation.
  1095  	if o.TLSConfig != nil {
  1096  		proto = wsSchemePrefixTLS
  1097  		config := o.TLSConfig.Clone()
  1098  		config.GetConfigForClient = s.wsGetTLSConfig
  1099  		hl, err = tls.Listen("tcp", hp, config)
  1100  	} else {
  1101  		proto = wsSchemePrefix
  1102  		hl, err = s.network.ListenCause("tcp", hp, "ws")
  1103  	}
  1104  	s.websocket.listenerErr = err
  1105  	if err != nil {
  1106  		s.mu.Unlock()
  1107  		s.Fatalf("Unable to listen for websocket connections: %v", err)
  1108  		return
  1109  	}
  1110  	if port == 0 {
  1111  		o.Port = hl.Addr().(*net.TCPAddr).Port
  1112  	}
  1113  	s.Noticef("Listening for websocket clients on %s://%s:%d", proto, o.Host, o.Port)
  1114  	if proto == wsSchemePrefix {
  1115  		s.Warnf("Websocket not configured with TLS. DO NOT USE IN PRODUCTION!")
  1116  	}
  1117  
  1118  	s.websocket.tls = proto == "wss"
  1119  	s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port)
  1120  	if err != nil {
  1121  		s.Fatalf("Unable to get websocket connect URLs: %v", err)
  1122  		hl.Close()
  1123  		s.mu.Unlock()
  1124  		return
  1125  	}
  1126  	hasLeaf := sopts.LeafNode.Port != 0
  1127  	mux := http.NewServeMux()
  1128  	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
  1129  		res, err := s.wsUpgrade(w, r)
  1130  		if err != nil {
  1131  			s.Errorf(err.Error())
  1132  			return
  1133  		}
  1134  		switch res.kind {
  1135  		case CLIENT:
  1136  			s.createWSClient(res.conn, res.ws)
  1137  		case MQTT:
  1138  			s.createMQTTClient(res.conn, res.ws)
  1139  		case LEAF:
  1140  			if !hasLeaf {
  1141  				s.Errorf("Not configured to accept leaf node connections")
  1142  				// Silently close for now. If we want to send an error back, we would
  1143  				// need to create the leafnode client anyway, so that is is handling websocket
  1144  				// frames, then send the error to the remote.
  1145  				res.conn.Close()
  1146  				return
  1147  			}
  1148  			s.createLeafNode(res.conn, nil, nil, res.ws)
  1149  		}
  1150  	})
  1151  	hs := &http.Server{
  1152  		Addr:        hp,
  1153  		Handler:     mux,
  1154  		ReadTimeout: o.HandshakeTimeout,
  1155  		ErrorLog:    log.New(&captureHTTPServerLog{s, "websocket: "}, _EMPTY_, 0),
  1156  	}
  1157  	s.websocket.server = hs
  1158  	s.websocket.listener = hl
  1159  	go func() {
  1160  		if err := hs.Serve(hl); err != http.ErrServerClosed {
  1161  			s.Fatalf("websocket listener error: %v", err)
  1162  		}
  1163  		if s.isLameDuckMode() {
  1164  			// Signal that we are not accepting new clients
  1165  			s.ldmCh <- true
  1166  			// Now wait for the Shutdown...
  1167  			<-s.quitCh
  1168  			return
  1169  		}
  1170  		s.done <- true
  1171  	}()
  1172  	s.mu.Unlock()
  1173  }
  1174  
  1175  // The TLS configuration is passed to the listener when the websocket
  1176  // "server" is setup. That prevents TLS configuration updates on reload
  1177  // from being used. By setting this function in tls.Config.GetConfigForClient
  1178  // we instruct the TLS handshake to ask for the tls configuration to be
  1179  // used for a specific client. We don't care which client, we always use
  1180  // the same TLS configuration.
  1181  func (s *Server) wsGetTLSConfig(_ *tls.ClientHelloInfo) (*tls.Config, error) {
  1182  	opts := s.getOpts()
  1183  	return opts.Websocket.TLSConfig, nil
  1184  }
  1185  
// This is similar to createClient() but has some modifications
// specific to handle websocket clients.
// The comments have been kept to minimum to reduce code size.
// Check createClient() for more details.
//
// createWSClient wraps an already upgraded websocket connection `conn`
// (with its websocket state `ws`) into a *client registered with the global
// account and sends it the server info (c.generateClientInfoJSON). If the
// server is running and not in lame duck mode, it also adds the client to
// s.clients and starts its read and write loops.
//
// Return values:
//   - c, not added to s.clients and with no loops started, when the server
//     is not running or is in lame duck mode (conn is closed only if the
//     server is shutting down);
//   - nil when opts.MaxConn is reached (after c.maxConnExceeded());
//   - nil when the client is found closed after registration (after
//     c.closeConnection(WriteError));
//   - c otherwise.
func (s *Server) createWSClient(conn net.Conn, ws *websocket) *client {
	opts := s.getOpts()

	maxPay := int32(opts.MaxPayload)
	maxSubs := int32(opts.MaxSubs)
	// A MaxSubs of 0 in the options is stored as -1 on the client
	// (presumably meaning "no limit"; see createClient()).
	if maxSubs == 0 {
		maxSubs = -1
	}
	now := time.Now().UTC()

	c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws}

	c.registerWithAccount(s.globalAccount())

	var info Info
	var authRequired bool

	// Build the INFO for this client under the server lock.
	s.mu.Lock()
	info = s.copyInfo()
	// Check auth, override if applicable.
	if !info.AuthRequired {
		// Set info.AuthRequired since this is what is sent to the client.
		info.AuthRequired = s.websocket.authOverride
	}
	if s.nonceRequired() {
		var raw [nonceLen]byte
		nonce := raw[:]
		s.generateNonce(nonce)
		info.Nonce = string(nonce)
	}
	c.nonce = []byte(info.Nonce)
	authRequired = info.AuthRequired

	s.totalClients++
	s.mu.Unlock()

	// Initialize the client and send the INFO right away.
	c.mu.Lock()
	if authRequired {
		c.flags.set(expectConnect)
	}
	c.initClient()
	c.Debugf("Client connection created")
	c.sendProtoNow(c.generateClientInfoJSON(info))
	c.mu.Unlock()

	s.mu.Lock()
	// Do not register the client (nor start its loops) if the server is not
	// running or is in lame duck mode; only close the connection if the
	// server is shutting down.
	if !s.isRunning() || s.ldm {
		if s.isShuttingDown() {
			conn.Close()
		}
		s.mu.Unlock()
		return c
	}

	if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn {
		s.mu.Unlock()
		c.maxConnExceeded()
		return nil
	}
	s.clients[c.cid] = c
	s.mu.Unlock()

	c.mu.Lock()
	// Websocket clients do TLS in the websocket http server.
	// So no TLS initiation here...
	if _, ok := conn.(*tls.Conn); ok {
		c.flags.set(handshakeComplete)
	}

	// The client may already be marked closed (presumably by a failed write
	// while sending the INFO above); if so, finish closing it and bail out.
	if c.isClosed() {
		c.mu.Unlock()
		c.closeConnection(WriteError)
		return nil
	}

	if authRequired {
		timeout := opts.AuthTimeout
		// Possibly override with Websocket specific value.
		if opts.Websocket.AuthTimeout != 0 {
			timeout = opts.Websocket.AuthTimeout
		}
		c.setAuthTimer(secondsToDuration(timeout))
	}

	c.setPingTimer()

	s.startGoRoutine(func() { c.readLoop(nil) })
	s.startGoRoutine(func() { c.writeLoop() })

	c.mu.Unlock()

	return c
}
  1283  
// wsCollapsePtoNB turns the client's pending outbound buffers (c.out.nb)
// into websocket frames. It returns the buffers to write together with
// c.ws.fs, which this function increments by the size of everything it adds
// (frame headers included).
//
// The result starts with any already framed buffers from c.ws.frames and
// ends with c.ws.closeMsg if one is pending; both fields are cleared.
//
// Compression and framing:
//   - When compression is enabled and the pending data is larger than
//     wsCompressThreshold, the data is deflated into one message; otherwise
//     it is framed as-is.
//   - For browser connections, frame payloads are limited to
//     wsFrameSizeForBrowsers bytes. The exception is compressed data when
//     c.ws.nocompfrag is set, which stays in a single frame.
//   - Payloads are masked when c.ws.maskwrite is set.
//   - c.out.pb is adjusted for the added frame headers (and, when
//     compressing, for the compressed vs. uncompressed size difference).
//
// Pooled buffers from c.out.nb are returned with nbPoolPut once their
// contents have been compressed or copied. In the unlimited-size
// uncompressed path they are instead appended to the result as-is.
//
// On a compression error the connection is marked closed (WriteError) and
// (nil, 0) is returned.
func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
	nb := c.out.nb
	var mfs int
	var usz int
	// mfs is the maximum frame payload size; 0 means no limit.
	if c.ws.browser {
		mfs = wsFrameSizeForBrowsers
	}
	mask := c.ws.maskwrite
	// Start with possible already framed buffers (that we could have
	// got from partials or control messages such as ws pings or pongs).
	bufs := c.ws.frames
	compress := c.ws.compress
	if compress && len(nb) > 0 {
		// First, make sure we don't compress for very small cumulative buffers.
		for _, b := range nb {
			usz += len(b)
		}
		if usz <= wsCompressThreshold {
			compress = false
		}
	}
	if compress && len(nb) > 0 {
		// Overwrite mfs if this connection does not support fragmented compressed frames.
		if mfs > 0 && c.ws.nocompfrag {
			mfs = 0
		}
		// The compressor is created lazily and reused across calls.
		buf := bytes.NewBuffer(nbPoolGet(usz))
		cp := c.ws.compressor
		if cp == nil {
			c.ws.compressor, _ = flate.NewWriter(buf, flate.BestSpeed)
			cp = c.ws.compressor
		} else {
			cp.Reset(buf)
		}
		var csz int
		for _, b := range nb {
			cp.Write(b)
			nbPoolPut(b) // No longer needed as contents written to compressor.
		}
		if err := cp.Flush(); err != nil {
			c.Errorf("Error during compression: %v", err)
			c.markConnAsClosed(WriteError)
			return nil, 0
		}
		b := buf.Bytes()
		// Drop the 4-byte tail (0x00 0x00 0xff 0xff) that Flush emits;
		// RFC 7692 section 7.2.1 requires removing it before sending.
		p := b[:len(b)-4]
		if mfs > 0 && len(p) > mfs {
			// Split the compressed payload into fragments of at most mfs bytes:
			// a first frame followed by continuation frames, the last one final.
			for first, final := true, false; len(p) > 0; first = false {
				lp := len(p)
				if lp > mfs {
					lp = mfs
				} else {
					final = true
				}
				// Only the first frame should be marked as compressed, so pass
				// `first` for the compressed boolean.
				fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]
				n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp)
				if mask {
					wsMaskBuf(key, p[:lp])
				}
				bufs = append(bufs, fh[:n], p[:lp])
				csz += n + lp
				p = p[lp:]
			}
		} else {
			// Single compressed frame.
			ol := len(p)
			h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, ol)
			if mask {
				wsMaskBuf(key, p)
			}
			if ol > 0 {
				bufs = append(bufs, h, p)
			}
			csz = len(h) + ol
		}
		// Make sure that the compressor no longer holds a reference to
		// the bytes.Buffer, so that the underlying memory gets cleaned
		// up after flushOutbound/flushAndClose. For this to be safe, we
		// always cp.Reset(...) before reusing the compressor again.
		cp.Reset(nil)
		// Add to pb the compressed data size (including headers), but
		// remove the original uncompressed data size that was added
		// during the queueing.
		c.out.pb += int64(csz) - int64(usz)
		c.ws.fs += int64(csz)
	} else if len(nb) > 0 {
		var total int
		if mfs > 0 {
			// We are limiting the frame size.
			// startFrame reserves a (not yet filled) header slot and returns
			// its index in bufs.
			startFrame := func() int {
				bufs = append(bufs, nbPoolGet(wsMaxFrameHeaderSize))
				return len(bufs) - 1
			}
			// endFrame fills the header at idx for a payload of `size` bytes.
			// Each frame is a complete (first and final), uncompressed message.
			// It then masks the payload buffers that follow the header, if needed.
			endFrame := func(idx, size int) {
				bufs[idx] = bufs[idx][:wsMaxFrameHeaderSize]
				n, key := wsFillFrameHeader(bufs[idx], mask, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, wsBinaryMessage, size)
				bufs[idx] = bufs[idx][:n]
				c.out.pb += int64(n)
				c.ws.fs += int64(n + size)
				if mask {
					wsMaskBufs(key, bufs[idx+1:])
				}
			}

			fhIdx := startFrame()
			for i := 0; i < len(nb); i++ {
				b := nb[i]
				// The whole buffer fits in the current frame: copy it in.
				if total+len(b) <= mfs {
					buf := nbPoolGet(len(b))
					bufs = append(bufs, append(buf, b...))
					total += len(b)
					nbPoolPut(nb[i])
					continue
				}
				// Otherwise, spread b over as many frames as needed. If the
				// current frame already holds data, close it and start a new one,
				// then move up to mfs bytes of b into the current frame.
				for len(b) > 0 {
					endStart := total != 0
					if endStart {
						endFrame(fhIdx, total)
					}
					total = len(b)
					if total >= mfs {
						total = mfs
					}
					if endStart {
						fhIdx = startFrame()
					}
					buf := nbPoolGet(total)
					bufs = append(bufs, append(buf, b[:total]...))
					b = b[total:]
				}
				nbPoolPut(nb[i]) // No longer needed as copied into smaller frames.
			}
			// Close the last frame if it holds any data.
			if total > 0 {
				endFrame(fhIdx, total)
			}
		} else {
			// If there is no limit on the frame size, create a single frame for
			// all pending buffers.
			for _, b := range nb {
				total += len(b)
			}
			wsfh, key := wsCreateFrameHeader(mask, false, wsBinaryMessage, total)
			c.out.pb += int64(len(wsfh))
			bufs = append(bufs, wsfh)
			idx := len(bufs)
			// The pending buffers are used as-is (masked in place if needed).
			bufs = append(bufs, nb...)
			if mask {
				wsMaskBufs(key, bufs[idx:])
			}
			c.ws.fs += int64(len(wsfh) + total)
		}
	}
	// A pending close message always goes last.
	if len(c.ws.closeMsg) > 0 {
		bufs = append(bufs, c.ws.closeMsg)
		c.ws.fs += int64(len(c.ws.closeMsg))
		c.ws.closeMsg = nil
	}
	c.ws.frames = nil
	return bufs, c.ws.fs
}
  1445  
  1446  func isWSURL(u *url.URL) bool {
  1447  	return strings.HasPrefix(strings.ToLower(u.Scheme), wsSchemePrefix)
  1448  }
  1449  
  1450  func isWSSURL(u *url.URL) bool {
  1451  	return strings.HasPrefix(strings.ToLower(u.Scheme), wsSchemePrefixTLS)
  1452  }