github.com/nats-io/nats-server/v2@v2.11.0-preview.2/server/websocket.go (about)

     1  // Copyright 2020-2024 The NATS Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package server
    15  
    16  import (
    17  	"bytes"
    18  	crand "crypto/rand"
    19  	"crypto/sha1"
    20  	"crypto/tls"
    21  	"encoding/base64"
    22  	"encoding/binary"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"log"
    27  	mrand "math/rand"
    28  	"net"
    29  	"net/http"
    30  	"net/url"
    31  	"strconv"
    32  	"strings"
    33  	"sync"
    34  	"time"
    35  	"unicode/utf8"
    36  
    37  	"github.com/klauspost/compress/flate"
    38  )
    39  
    40  type wsOpCode int
    41  
    42  const (
    43  	// From https://tools.ietf.org/html/rfc6455#section-5.2
    44  	wsTextMessage   = wsOpCode(1)
    45  	wsBinaryMessage = wsOpCode(2)
    46  	wsCloseMessage  = wsOpCode(8)
    47  	wsPingMessage   = wsOpCode(9)
    48  	wsPongMessage   = wsOpCode(10)
    49  
    50  	wsFinalBit = 1 << 7
    51  	wsRsv1Bit  = 1 << 6 // Used for compression, from https://tools.ietf.org/html/rfc7692#section-6
    52  	wsRsv2Bit  = 1 << 5
    53  	wsRsv3Bit  = 1 << 4
    54  
    55  	wsMaskBit = 1 << 7
    56  
    57  	wsContinuationFrame     = 0
    58  	wsMaxFrameHeaderSize    = 14 // Since LeafNode may need to behave as a client
    59  	wsMaxControlPayloadSize = 125
    60  	wsFrameSizeForBrowsers  = 4096 // From experiment, webrowsers behave better with limited frame size
    61  	wsCompressThreshold     = 64   // Don't compress for small buffer(s)
    62  	wsCloseSatusSize        = 2
    63  
    64  	// From https://tools.ietf.org/html/rfc6455#section-11.7
    65  	wsCloseStatusNormalClosure      = 1000
    66  	wsCloseStatusGoingAway          = 1001
    67  	wsCloseStatusProtocolError      = 1002
    68  	wsCloseStatusUnsupportedData    = 1003
    69  	wsCloseStatusNoStatusReceived   = 1005
    70  	wsCloseStatusAbnormalClosure    = 1006
    71  	wsCloseStatusInvalidPayloadData = 1007
    72  	wsCloseStatusPolicyViolation    = 1008
    73  	wsCloseStatusMessageTooBig      = 1009
    74  	wsCloseStatusInternalSrvError   = 1011
    75  	wsCloseStatusTLSHandshake       = 1015
    76  
    77  	wsFirstFrame        = true
    78  	wsContFrame         = false
    79  	wsFinalFrame        = true
    80  	wsUncompressedFrame = false
    81  
    82  	wsSchemePrefix    = "ws"
    83  	wsSchemePrefixTLS = "wss"
    84  
    85  	wsNoMaskingHeader       = "Nats-No-Masking"
    86  	wsNoMaskingValue        = "true"
    87  	wsXForwardedForHeader   = "X-Forwarded-For"
    88  	wsNoMaskingFullResponse = wsNoMaskingHeader + ": " + wsNoMaskingValue + CR_LF
    89  	wsPMCExtension          = "permessage-deflate" // per-message compression
    90  	wsPMCSrvNoCtx           = "server_no_context_takeover"
    91  	wsPMCCliNoCtx           = "client_no_context_takeover"
    92  	wsPMCReqHeaderValue     = wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx
    93  	wsPMCFullResponse       = "Sec-WebSocket-Extensions: " + wsPMCExtension + "; " + wsPMCSrvNoCtx + "; " + wsPMCCliNoCtx + _CRLF_
    94  	wsSecProto              = "Sec-Websocket-Protocol"
    95  	wsMQTTSecProtoVal       = "mqtt"
    96  	wsMQTTSecProto          = wsSecProto + ": " + wsMQTTSecProtoVal + CR_LF
    97  )
    98  
    99  var decompressorPool sync.Pool
   100  var compressLastBlock = []byte{0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff}
   101  
   102  // From https://tools.ietf.org/html/rfc6455#section-1.3
   103  var wsGUID = []byte("258EAFA5-E914-47DA-95CA-C5AB0DC85B11")
   104  
   105  // Test can enable this so that server does not support "no-masking" requests.
   106  var wsTestRejectNoMasking = false
   107  
   108  type websocket struct {
   109  	frames         net.Buffers
   110  	fs             int64
   111  	closeMsg       []byte
   112  	compress       bool
   113  	closeSent      bool
   114  	browser        bool
   115  	nocompfrag     bool // No fragment for compressed frames
   116  	maskread       bool
   117  	maskwrite      bool
   118  	compressor     *flate.Writer
   119  	cookieJwt      string
   120  	cookieUsername string
   121  	cookiePassword string
   122  	cookieToken    string
   123  	clientIP       string
   124  }
   125  
   126  type srvWebsocket struct {
   127  	mu             sync.RWMutex
   128  	server         *http.Server
   129  	listener       net.Listener
   130  	listenerErr    error
   131  	tls            bool
   132  	allowedOrigins map[string]*allowedOrigin // host will be the key
   133  	sameOrigin     bool
   134  	connectURLs    []string
   135  	connectURLsMap refCountedUrlSet
   136  	authOverride   bool   // indicate if there is auth override in websocket config
   137  	rawHeaders     string // raw headers to be used in the upgrade response.
   138  }
   139  
   140  type allowedOrigin struct {
   141  	scheme string
   142  	port   string
   143  }
   144  
   145  type wsUpgradeResult struct {
   146  	conn net.Conn
   147  	ws   *websocket
   148  	kind int
   149  }
   150  
   151  type wsReadInfo struct {
   152  	rem   int
   153  	fs    bool
   154  	ff    bool
   155  	fc    bool
   156  	mask  bool // Incoming leafnode connections may not have masking.
   157  	mkpos byte
   158  	mkey  [4]byte
   159  	cbufs [][]byte
   160  	coff  int
   161  }
   162  
   163  func (r *wsReadInfo) init() {
   164  	r.fs, r.ff = true, true
   165  }
   166  
   167  // Returns a slice containing `needed` bytes from the given buffer `buf`
   168  // starting at position `pos`, and possibly read from the given reader `r`.
   169  // When bytes are present in `buf`, the `pos` is incremented by the number
   170  // of bytes found up to `needed` and the new position is returned. If not
   171  // enough bytes are found, the bytes found in `buf` are copied to the returned
   172  // slice and the remaning bytes are read from `r`.
   173  func wsGet(r io.Reader, buf []byte, pos, needed int) ([]byte, int, error) {
   174  	avail := len(buf) - pos
   175  	if avail >= needed {
   176  		return buf[pos : pos+needed], pos + needed, nil
   177  	}
   178  	b := make([]byte, needed)
   179  	start := copy(b, buf[pos:])
   180  	for start != needed {
   181  		n, err := r.Read(b[start:cap(b)])
   182  		if err != nil {
   183  			return nil, 0, err
   184  		}
   185  		start += n
   186  	}
   187  	return b, pos + avail, nil
   188  }
   189  
   190  // Returns true if this connection is from a Websocket client.
   191  // Lock held on entry.
   192  func (c *client) isWebsocket() bool {
   193  	return c.ws != nil
   194  }
   195  
   196  // Returns a slice of byte slices corresponding to payload of websocket frames.
   197  // The byte slice `buf` is filled with bytes from the connection's read loop.
   198  // This function will decode the frame headers and unmask the payload(s).
   199  // It is possible that the returned slices point to the given `buf` slice, so
   200  // `buf` should not be overwritten until the returned slices have been parsed.
   201  //
   202  // Client lock MUST NOT be held on entry.
   203  func (c *client) wsRead(r *wsReadInfo, ior io.Reader, buf []byte) ([][]byte, error) {
   204  	var (
   205  		bufs   [][]byte
   206  		tmpBuf []byte
   207  		err    error
   208  		pos    int
   209  		max    = len(buf)
   210  	)
   211  	for pos != max {
   212  		if r.fs {
   213  			b0 := buf[pos]
   214  			frameType := wsOpCode(b0 & 0xF)
   215  			final := b0&wsFinalBit != 0
   216  			compressed := b0&wsRsv1Bit != 0
   217  			pos++
   218  
   219  			tmpBuf, pos, err = wsGet(ior, buf, pos, 1)
   220  			if err != nil {
   221  				return bufs, err
   222  			}
   223  			b1 := tmpBuf[0]
   224  
   225  			// Clients MUST set the mask bit. If not set, reject.
   226  			// However, LEAF by default will not have masking, unless they are forced to, by configuration.
   227  			if r.mask && b1&wsMaskBit == 0 {
   228  				return bufs, c.wsHandleProtocolError("mask bit missing")
   229  			}
   230  
   231  			// Store size in case it is < 125
   232  			r.rem = int(b1 & 0x7F)
   233  
   234  			switch frameType {
   235  			case wsPingMessage, wsPongMessage, wsCloseMessage:
   236  				if r.rem > wsMaxControlPayloadSize {
   237  					return bufs, c.wsHandleProtocolError(
   238  						fmt.Sprintf("control frame length bigger than maximum allowed of %v bytes",
   239  							wsMaxControlPayloadSize))
   240  				}
   241  				if !final {
   242  					return bufs, c.wsHandleProtocolError("control frame does not have final bit set")
   243  				}
   244  			case wsTextMessage, wsBinaryMessage:
   245  				if !r.ff {
   246  					return bufs, c.wsHandleProtocolError("new message started before final frame for previous message was received")
   247  				}
   248  				r.ff = final
   249  				r.fc = compressed
   250  			case wsContinuationFrame:
   251  				// Compressed bit must be only set in the first frame
   252  				if r.ff || compressed {
   253  					return bufs, c.wsHandleProtocolError("invalid continuation frame")
   254  				}
   255  				r.ff = final
   256  			default:
   257  				return bufs, c.wsHandleProtocolError(fmt.Sprintf("unknown opcode %v", frameType))
   258  			}
   259  
   260  			switch r.rem {
   261  			case 126:
   262  				tmpBuf, pos, err = wsGet(ior, buf, pos, 2)
   263  				if err != nil {
   264  					return bufs, err
   265  				}
   266  				r.rem = int(binary.BigEndian.Uint16(tmpBuf))
   267  			case 127:
   268  				tmpBuf, pos, err = wsGet(ior, buf, pos, 8)
   269  				if err != nil {
   270  					return bufs, err
   271  				}
   272  				r.rem = int(binary.BigEndian.Uint64(tmpBuf))
   273  			}
   274  
   275  			if r.mask {
   276  				// Read masking key
   277  				tmpBuf, pos, err = wsGet(ior, buf, pos, 4)
   278  				if err != nil {
   279  					return bufs, err
   280  				}
   281  				copy(r.mkey[:], tmpBuf)
   282  				r.mkpos = 0
   283  			}
   284  
   285  			// Handle control messages in place...
   286  			if wsIsControlFrame(frameType) {
   287  				pos, err = c.wsHandleControlFrame(r, frameType, ior, buf, pos)
   288  				if err != nil {
   289  					return bufs, err
   290  				}
   291  				continue
   292  			}
   293  
   294  			// Done with the frame header
   295  			r.fs = false
   296  		}
   297  		if pos < max {
   298  			var b []byte
   299  			var n int
   300  
   301  			n = r.rem
   302  			if pos+n > max {
   303  				n = max - pos
   304  			}
   305  			b = buf[pos : pos+n]
   306  			pos += n
   307  			r.rem -= n
   308  			// If needed, unmask the buffer
   309  			if r.mask {
   310  				r.unmask(b)
   311  			}
   312  			addToBufs := true
   313  			// Handle compressed message
   314  			if r.fc {
   315  				// Assume that we may have continuation frames or not the full payload.
   316  				addToBufs = false
   317  				// Make a copy of the buffer before adding it to the list
   318  				// of compressed fragments.
   319  				r.cbufs = append(r.cbufs, append([]byte(nil), b...))
   320  				// When we have the final frame and we have read the full payload,
   321  				// we can decompress it.
   322  				if r.ff && r.rem == 0 {
   323  					b, err = r.decompress()
   324  					if err != nil {
   325  						return bufs, err
   326  					}
   327  					r.fc = false
   328  					// Now we can add to `bufs`
   329  					addToBufs = true
   330  				}
   331  			}
   332  			// For non compressed frames, or when we have decompressed the
   333  			// whole message.
   334  			if addToBufs {
   335  				bufs = append(bufs, b)
   336  			}
   337  			// If payload has been fully read, then indicate that next
   338  			// is the start of a frame.
   339  			if r.rem == 0 {
   340  				r.fs = true
   341  			}
   342  		}
   343  	}
   344  	return bufs, nil
   345  }
   346  
   347  func (r *wsReadInfo) Read(dst []byte) (int, error) {
   348  	if len(dst) == 0 {
   349  		return 0, nil
   350  	}
   351  	if len(r.cbufs) == 0 {
   352  		return 0, io.EOF
   353  	}
   354  	copied := 0
   355  	rem := len(dst)
   356  	for buf := r.cbufs[0]; buf != nil && rem > 0; {
   357  		n := len(buf[r.coff:])
   358  		if n > rem {
   359  			n = rem
   360  		}
   361  		copy(dst[copied:], buf[r.coff:r.coff+n])
   362  		copied += n
   363  		rem -= n
   364  		r.coff += n
   365  		buf = r.nextCBuf()
   366  	}
   367  	return copied, nil
   368  }
   369  
   370  func (r *wsReadInfo) nextCBuf() []byte {
   371  	// We still have remaining data in the first buffer
   372  	if r.coff != len(r.cbufs[0]) {
   373  		return r.cbufs[0]
   374  	}
   375  	// We read the full first buffer. Reset offset.
   376  	r.coff = 0
   377  	// We were at the last buffer, so we are done.
   378  	if len(r.cbufs) == 1 {
   379  		r.cbufs = nil
   380  		return nil
   381  	}
   382  	// Here we move to the next buffer.
   383  	r.cbufs = r.cbufs[1:]
   384  	return r.cbufs[0]
   385  }
   386  
   387  func (r *wsReadInfo) ReadByte() (byte, error) {
   388  	if len(r.cbufs) == 0 {
   389  		return 0, io.EOF
   390  	}
   391  	b := r.cbufs[0][r.coff]
   392  	r.coff++
   393  	r.nextCBuf()
   394  	return b, nil
   395  }
   396  
   397  func (r *wsReadInfo) decompress() ([]byte, error) {
   398  	r.coff = 0
   399  	// As per https://tools.ietf.org/html/rfc7692#section-7.2.2
   400  	// add 0x00, 0x00, 0xff, 0xff and then a final block so that flate reader
   401  	// does not report unexpected EOF.
   402  	r.cbufs = append(r.cbufs, compressLastBlock)
   403  	// Get a decompressor from the pool and bind it to this object (wsReadInfo)
   404  	// that provides Read() and ReadByte() APIs that will consume the compressed
   405  	// buffers (r.cbufs).
   406  	d, _ := decompressorPool.Get().(io.ReadCloser)
   407  	if d == nil {
   408  		d = flate.NewReader(r)
   409  	} else {
   410  		d.(flate.Resetter).Reset(r, nil)
   411  	}
   412  	// This will do the decompression.
   413  	b, err := io.ReadAll(d)
   414  	decompressorPool.Put(d)
   415  	// Now reset the compressed buffers list.
   416  	r.cbufs = nil
   417  	return b, err
   418  }
   419  
   420  // Handles the PING, PONG and CLOSE websocket control frames.
   421  //
   422  // Client lock MUST NOT be held on entry.
   423  func (c *client) wsHandleControlFrame(r *wsReadInfo, frameType wsOpCode, nc io.Reader, buf []byte, pos int) (int, error) {
   424  	var payload []byte
   425  	var err error
   426  
   427  	if r.rem > 0 {
   428  		payload, pos, err = wsGet(nc, buf, pos, r.rem)
   429  		if err != nil {
   430  			return pos, err
   431  		}
   432  		if r.mask {
   433  			r.unmask(payload)
   434  		}
   435  		r.rem = 0
   436  	}
   437  	switch frameType {
   438  	case wsCloseMessage:
   439  		status := wsCloseStatusNoStatusReceived
   440  		var body string
   441  		lp := len(payload)
   442  		// If there is a payload, the status is represented as a 2-byte
   443  		// unsigned integer (in network byte order). Then, there may be an
   444  		// optional body.
   445  		hasStatus, hasBody := lp >= wsCloseSatusSize, lp > wsCloseSatusSize
   446  		if hasStatus {
   447  			// Decode the status
   448  			status = int(binary.BigEndian.Uint16(payload[:wsCloseSatusSize]))
   449  			// Now if there is a body, capture it and make sure this is a valid UTF-8.
   450  			if hasBody {
   451  				body = string(payload[wsCloseSatusSize:])
   452  				if !utf8.ValidString(body) {
   453  					// https://tools.ietf.org/html/rfc6455#section-5.5.1
   454  					// If body is present, it must be a valid utf8
   455  					status = wsCloseStatusInvalidPayloadData
   456  					body = "invalid utf8 body in close frame"
   457  				}
   458  			}
   459  		}
   460  		clm := wsCreateCloseMessage(status, body)
   461  		c.wsEnqueueControlMessage(wsCloseMessage, clm)
   462  		nbPoolPut(clm) // wsEnqueueControlMessage has taken a copy.
   463  		// Return io.EOF so that readLoop will close the connection as ClientClosed
   464  		// after processing pending buffers.
   465  		return pos, io.EOF
   466  	case wsPingMessage:
   467  		c.wsEnqueueControlMessage(wsPongMessage, payload)
   468  	case wsPongMessage:
   469  		// Nothing to do..
   470  	}
   471  	return pos, nil
   472  }
   473  
   474  // Unmask the given slice.
   475  func (r *wsReadInfo) unmask(buf []byte) {
   476  	p := int(r.mkpos)
   477  	if len(buf) < 16 {
   478  		for i := 0; i < len(buf); i++ {
   479  			buf[i] ^= r.mkey[p&3]
   480  			p++
   481  		}
   482  		r.mkpos = byte(p & 3)
   483  		return
   484  	}
   485  	var k [8]byte
   486  	for i := 0; i < 8; i++ {
   487  		k[i] = r.mkey[(p+i)&3]
   488  	}
   489  	km := binary.BigEndian.Uint64(k[:])
   490  	n := (len(buf) / 8) * 8
   491  	for i := 0; i < n; i += 8 {
   492  		tmp := binary.BigEndian.Uint64(buf[i : i+8])
   493  		tmp ^= km
   494  		binary.BigEndian.PutUint64(buf[i:], tmp)
   495  	}
   496  	buf = buf[n:]
   497  	for i := 0; i < len(buf); i++ {
   498  		buf[i] ^= r.mkey[p&3]
   499  		p++
   500  	}
   501  	r.mkpos = byte(p & 3)
   502  }
   503  
   504  // Returns true if the op code corresponds to a control frame.
   505  func wsIsControlFrame(frameType wsOpCode) bool {
   506  	return frameType >= wsCloseMessage
   507  }
   508  
   509  // Create the frame header.
   510  // Encodes the frame type and optional compression flag, and the size of the payload.
   511  func wsCreateFrameHeader(useMasking, compressed bool, frameType wsOpCode, l int) ([]byte, []byte) {
   512  	fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]
   513  	n, key := wsFillFrameHeader(fh, useMasking, wsFirstFrame, wsFinalFrame, compressed, frameType, l)
   514  	return fh[:n], key
   515  }
   516  
   517  func wsFillFrameHeader(fh []byte, useMasking, first, final, compressed bool, frameType wsOpCode, l int) (int, []byte) {
   518  	var n int
   519  	var b byte
   520  	if first {
   521  		b = byte(frameType)
   522  	}
   523  	if final {
   524  		b |= wsFinalBit
   525  	}
   526  	if compressed {
   527  		b |= wsRsv1Bit
   528  	}
   529  	b1 := byte(0)
   530  	if useMasking {
   531  		b1 |= wsMaskBit
   532  	}
   533  	switch {
   534  	case l <= 125:
   535  		n = 2
   536  		fh[0] = b
   537  		fh[1] = b1 | byte(l)
   538  	case l < 65536:
   539  		n = 4
   540  		fh[0] = b
   541  		fh[1] = b1 | 126
   542  		binary.BigEndian.PutUint16(fh[2:], uint16(l))
   543  	default:
   544  		n = 10
   545  		fh[0] = b
   546  		fh[1] = b1 | 127
   547  		binary.BigEndian.PutUint64(fh[2:], uint64(l))
   548  	}
   549  	var key []byte
   550  	if useMasking {
   551  		var keyBuf [4]byte
   552  		if _, err := io.ReadFull(crand.Reader, keyBuf[:4]); err != nil {
   553  			kv := mrand.Int31()
   554  			binary.LittleEndian.PutUint32(keyBuf[:4], uint32(kv))
   555  		}
   556  		copy(fh[n:], keyBuf[:4])
   557  		key = fh[n : n+4]
   558  		n += 4
   559  	}
   560  	return n, key
   561  }
   562  
   563  // Invokes wsEnqueueControlMessageLocked under client lock.
   564  //
   565  // Client lock MUST NOT be held on entry
   566  func (c *client) wsEnqueueControlMessage(controlMsg wsOpCode, payload []byte) {
   567  	c.mu.Lock()
   568  	c.wsEnqueueControlMessageLocked(controlMsg, payload)
   569  	c.mu.Unlock()
   570  }
   571  
   572  // Mask the buffer with the given key
   573  func wsMaskBuf(key, buf []byte) {
   574  	for i := 0; i < len(buf); i++ {
   575  		buf[i] ^= key[i&3]
   576  	}
   577  }
   578  
   579  // Mask the buffers, as if they were contiguous, with the given key
   580  func wsMaskBufs(key []byte, bufs [][]byte) {
   581  	pos := 0
   582  	for i := 0; i < len(bufs); i++ {
   583  		buf := bufs[i]
   584  		for j := 0; j < len(buf); j++ {
   585  			buf[j] ^= key[pos&3]
   586  			pos++
   587  		}
   588  	}
   589  }
   590  
   591  // Enqueues a websocket control message.
   592  // If the control message is a wsCloseMessage, then marks this client
   593  // has having sent the close message (since only one should be sent).
   594  // This will prevent the generic closeConnection() to enqueue one.
   595  //
   596  // Client lock held on entry.
   597  func (c *client) wsEnqueueControlMessageLocked(controlMsg wsOpCode, payload []byte) {
   598  	// Control messages are never compressed and their size will be
   599  	// less than wsMaxControlPayloadSize, which means the frame header
   600  	// will be only 2 or 6 bytes.
   601  	useMasking := c.ws.maskwrite
   602  	sz := 2
   603  	if useMasking {
   604  		sz += 4
   605  	}
   606  	cm := nbPoolGet(sz + len(payload))
   607  	cm = cm[:cap(cm)]
   608  	n, key := wsFillFrameHeader(cm, useMasking, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, controlMsg, len(payload))
   609  	cm = cm[:n]
   610  	// Note that payload is optional.
   611  	if len(payload) > 0 {
   612  		cm = append(cm, payload...)
   613  		if useMasking {
   614  			wsMaskBuf(key, cm[n:])
   615  		}
   616  	}
   617  	c.out.pb += int64(len(cm))
   618  	if controlMsg == wsCloseMessage {
   619  		// We can't add the close message to the frames buffers
   620  		// now. It will be done on a flushOutbound() when there
   621  		// are no more pending buffers to send.
   622  		c.ws.closeSent = true
   623  		c.ws.closeMsg = cm
   624  	} else {
   625  		c.ws.frames = append(c.ws.frames, cm)
   626  		c.ws.fs += int64(len(cm))
   627  	}
   628  	c.flushSignal()
   629  }
   630  
   631  // Enqueues a websocket close message with a status mapped from the given `reason`.
   632  //
   633  // Client lock held on entry
   634  func (c *client) wsEnqueueCloseMessage(reason ClosedState) {
   635  	var status int
   636  	switch reason {
   637  	case ClientClosed:
   638  		status = wsCloseStatusNormalClosure
   639  	case AuthenticationTimeout, AuthenticationViolation, SlowConsumerPendingBytes, SlowConsumerWriteDeadline,
   640  		MaxAccountConnectionsExceeded, MaxConnectionsExceeded, MaxControlLineExceeded, MaxSubscriptionsExceeded,
   641  		MissingAccount, AuthenticationExpired, Revocation:
   642  		status = wsCloseStatusPolicyViolation
   643  	case TLSHandshakeError:
   644  		status = wsCloseStatusTLSHandshake
   645  	case ParseError, ProtocolViolation, BadClientProtocolVersion:
   646  		status = wsCloseStatusProtocolError
   647  	case MaxPayloadExceeded:
   648  		status = wsCloseStatusMessageTooBig
   649  	case ServerShutdown:
   650  		status = wsCloseStatusGoingAway
   651  	case WriteError, ReadError, StaleConnection:
   652  		status = wsCloseStatusAbnormalClosure
   653  	default:
   654  		status = wsCloseStatusInternalSrvError
   655  	}
   656  	body := wsCreateCloseMessage(status, reason.String())
   657  	c.wsEnqueueControlMessageLocked(wsCloseMessage, body)
   658  	nbPoolPut(body) // wsEnqueueControlMessageLocked has taken a copy.
   659  }
   660  
   661  // Create and then enqueue a close message with a protocol error and the
   662  // given message. This is invoked when parsing websocket frames.
   663  //
   664  // Lock MUST NOT be held on entry.
   665  func (c *client) wsHandleProtocolError(message string) error {
   666  	buf := wsCreateCloseMessage(wsCloseStatusProtocolError, message)
   667  	c.wsEnqueueControlMessage(wsCloseMessage, buf)
   668  	nbPoolPut(buf) // wsEnqueueControlMessage has taken a copy.
   669  	return fmt.Errorf(message)
   670  }
   671  
   672  // Create a close message with the given `status` and `body`.
   673  // If the `body` is more than the maximum allows control frame payload size,
   674  // it is truncated and "..." is added at the end (as a hint that message
   675  // is not complete).
   676  func wsCreateCloseMessage(status int, body string) []byte {
   677  	// Since a control message payload is limited in size, we
   678  	// will limit the text and add trailing "..." if truncated.
   679  	// The body of a Close Message must be preceded with 2 bytes,
   680  	// so take that into account for limiting the body length.
   681  	if len(body) > wsMaxControlPayloadSize-2 {
   682  		body = body[:wsMaxControlPayloadSize-5]
   683  		body += "..."
   684  	}
   685  	buf := nbPoolGet(2 + len(body))[:2+len(body)]
   686  	// We need to have a 2 byte unsigned int that represents the error status code
   687  	// https://tools.ietf.org/html/rfc6455#section-5.5.1
   688  	binary.BigEndian.PutUint16(buf[:2], uint16(status))
   689  	copy(buf[2:], []byte(body))
   690  	return buf
   691  }
   692  
   693  // Process websocket client handshake. On success, returns the raw net.Conn that
   694  // will be used to create a *client object.
   695  // Invoked from the HTTP server listening on websocket port.
   696  func (s *Server) wsUpgrade(w http.ResponseWriter, r *http.Request) (*wsUpgradeResult, error) {
   697  	kind := CLIENT
   698  	if r.URL != nil {
   699  		ep := r.URL.EscapedPath()
   700  		if strings.HasSuffix(ep, leafNodeWSPath) {
   701  			kind = LEAF
   702  		} else if strings.HasSuffix(ep, mqttWSPath) {
   703  			kind = MQTT
   704  		}
   705  	}
   706  
   707  	opts := s.getOpts()
   708  
   709  	// From https://tools.ietf.org/html/rfc6455#section-4.2.1
   710  	// Point 1.
   711  	if r.Method != "GET" {
   712  		return nil, wsReturnHTTPError(w, r, http.StatusMethodNotAllowed, "request method must be GET")
   713  	}
   714  	// Point 2.
   715  	if r.Host == _EMPTY_ {
   716  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "'Host' missing in request")
   717  	}
   718  	// Point 3.
   719  	if !wsHeaderContains(r.Header, "Upgrade", "websocket") {
   720  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Upgrade'")
   721  	}
   722  	// Point 4.
   723  	if !wsHeaderContains(r.Header, "Connection", "Upgrade") {
   724  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid value for header 'Connection'")
   725  	}
   726  	// Point 5.
   727  	key := r.Header.Get("Sec-Websocket-Key")
   728  	if key == _EMPTY_ {
   729  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "key missing")
   730  	}
   731  	// Point 6.
   732  	if !wsHeaderContains(r.Header, "Sec-Websocket-Version", "13") {
   733  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "invalid version")
   734  	}
   735  	// Others are optional
   736  	// Point 7.
   737  	if err := s.websocket.checkOrigin(r); err != nil {
   738  		return nil, wsReturnHTTPError(w, r, http.StatusForbidden, fmt.Sprintf("origin not allowed: %v", err))
   739  	}
   740  	// Point 8.
   741  	// We don't have protocols, so ignore.
   742  	// Point 9.
   743  	// Extensions, only support for compression at the moment
   744  	compress := opts.Websocket.Compression
   745  	if compress {
   746  		// Simply check if permessage-deflate extension is present.
   747  		compress, _ = wsPMCExtensionSupport(r.Header, true)
   748  	}
   749  	// We will do masking if asked (unless we reject for tests)
   750  	noMasking := r.Header.Get(wsNoMaskingHeader) == wsNoMaskingValue && !wsTestRejectNoMasking
   751  
   752  	h := w.(http.Hijacker)
   753  	conn, brw, err := h.Hijack()
   754  	if err != nil {
   755  		if conn != nil {
   756  			conn.Close()
   757  		}
   758  		return nil, wsReturnHTTPError(w, r, http.StatusInternalServerError, err.Error())
   759  	}
   760  	if brw.Reader.Buffered() > 0 {
   761  		conn.Close()
   762  		return nil, wsReturnHTTPError(w, r, http.StatusBadRequest, "client sent data before handshake is complete")
   763  	}
   764  
   765  	var buf [1024]byte
   766  	p := buf[:0]
   767  
   768  	// From https://tools.ietf.org/html/rfc6455#section-4.2.2
   769  	p = append(p, "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "...)
   770  	p = append(p, wsAcceptKey(key)...)
   771  	p = append(p, _CRLF_...)
   772  	if compress {
   773  		p = append(p, wsPMCFullResponse...)
   774  	}
   775  	if noMasking {
   776  		p = append(p, wsNoMaskingFullResponse...)
   777  	}
   778  	if kind == MQTT {
   779  		p = append(p, wsMQTTSecProto...)
   780  	}
   781  	if s.websocket.rawHeaders != _EMPTY_ {
   782  		p = append(p, s.websocket.rawHeaders...)
   783  	}
   784  	p = append(p, _CRLF_...)
   785  
   786  	if _, err = conn.Write(p); err != nil {
   787  		conn.Close()
   788  		return nil, err
   789  	}
   790  	// If there was a deadline set for the handshake, clear it now.
   791  	if opts.Websocket.HandshakeTimeout > 0 {
   792  		conn.SetDeadline(time.Time{})
   793  	}
   794  	// Server always expect "clients" to send masked payload, unless the option
   795  	// "no-masking" has been enabled.
   796  	ws := &websocket{compress: compress, maskread: !noMasking}
   797  
   798  	// Check for X-Forwarded-For header
   799  	if cips, ok := r.Header[wsXForwardedForHeader]; ok {
   800  		cip := cips[0]
   801  		if net.ParseIP(cip) != nil {
   802  			ws.clientIP = cip
   803  		}
   804  	}
   805  
   806  	if kind == CLIENT || kind == MQTT {
   807  		// Indicate if this is likely coming from a browser.
   808  		if ua := r.Header.Get("User-Agent"); ua != _EMPTY_ && strings.HasPrefix(ua, "Mozilla/") {
   809  			ws.browser = true
   810  			// Disable fragmentation of compressed frames for Safari browsers.
   811  			// Unfortunately, you could be running Chrome on macOS and this
   812  			// string will contain "Safari/" (along "Chrome/"). However, what
   813  			// I have found is that actual Safari browser also have "Version/".
   814  			// So make the combination of the two.
   815  			ws.nocompfrag = ws.compress && strings.Contains(ua, "Version/") && strings.Contains(ua, "Safari/")
   816  		}
   817  
   818  		if cookies := r.Cookies(); len(cookies) > 0 {
   819  			ows := &opts.Websocket
   820  			for _, c := range cookies {
   821  				if ows.JWTCookie == c.Name {
   822  					ws.cookieJwt = c.Value
   823  				} else if ows.UsernameCookie == c.Name {
   824  					ws.cookieUsername = c.Value
   825  				} else if ows.PasswordCookie == c.Name {
   826  					ws.cookiePassword = c.Value
   827  				} else if ows.TokenCookie == c.Name {
   828  					ws.cookieToken = c.Value
   829  				}
   830  			}
   831  		}
   832  	}
   833  	return &wsUpgradeResult{conn: conn, ws: ws, kind: kind}, nil
   834  }
   835  
   836  // Returns true if the header named `name` contains a token with value `value`.
   837  func wsHeaderContains(header http.Header, name string, value string) bool {
   838  	for _, s := range header[name] {
   839  		tokens := strings.Split(s, ",")
   840  		for _, t := range tokens {
   841  			t = strings.Trim(t, " \t")
   842  			if strings.EqualFold(t, value) {
   843  				return true
   844  			}
   845  		}
   846  	}
   847  	return false
   848  }
   849  
   850  func wsPMCExtensionSupport(header http.Header, checkPMCOnly bool) (bool, bool) {
   851  	for _, extensionList := range header["Sec-Websocket-Extensions"] {
   852  		extensions := strings.Split(extensionList, ",")
   853  		for _, extension := range extensions {
   854  			extension = strings.Trim(extension, " \t")
   855  			params := strings.Split(extension, ";")
   856  			for i, p := range params {
   857  				p = strings.Trim(p, " \t")
   858  				if strings.EqualFold(p, wsPMCExtension) {
   859  					if checkPMCOnly {
   860  						return true, false
   861  					}
   862  					var snc bool
   863  					var cnc bool
   864  					for j := i + 1; j < len(params); j++ {
   865  						p = params[j]
   866  						p = strings.Trim(p, " \t")
   867  						if strings.EqualFold(p, wsPMCSrvNoCtx) {
   868  							snc = true
   869  						} else if strings.EqualFold(p, wsPMCCliNoCtx) {
   870  							cnc = true
   871  						}
   872  						if snc && cnc {
   873  							return true, true
   874  						}
   875  					}
   876  					return true, false
   877  				}
   878  			}
   879  		}
   880  	}
   881  	return false, false
   882  }
   883  
   884  // Send an HTTP error with the given `status` to the given http response writer `w`.
   885  // Return an error created based on the `reason` string.
   886  func wsReturnHTTPError(w http.ResponseWriter, r *http.Request, status int, reason string) error {
   887  	err := fmt.Errorf("%s - websocket handshake error: %s", r.RemoteAddr, reason)
   888  	w.Header().Set("Sec-Websocket-Version", "13")
   889  	http.Error(w, http.StatusText(status), status)
   890  	return err
   891  }
   892  
   893  // If the server is configured to accept any origin, then this function returns
   894  // `nil` without checking if the Origin is present and valid. This is also
   895  // the case if the request does not have the Origin header.
   896  // Otherwise, this will check that the Origin matches the same origin or
   897  // any origin in the allowed list.
   898  func (w *srvWebsocket) checkOrigin(r *http.Request) error {
   899  	w.mu.RLock()
   900  	checkSame := w.sameOrigin
   901  	listEmpty := len(w.allowedOrigins) == 0
   902  	w.mu.RUnlock()
   903  	if !checkSame && listEmpty {
   904  		return nil
   905  	}
   906  	origin := r.Header.Get("Origin")
   907  	if origin == _EMPTY_ {
   908  		origin = r.Header.Get("Sec-Websocket-Origin")
   909  	}
   910  	// If the header is not present, we will accept.
   911  	// From https://datatracker.ietf.org/doc/html/rfc6455#section-1.6
   912  	// "Naturally, when the WebSocket Protocol is used by a dedicated client
   913  	// directly (i.e., not from a web page through a web browser), the origin
   914  	// model is not useful, as the client can provide any arbitrary origin string."
   915  	if origin == _EMPTY_ {
   916  		return nil
   917  	}
   918  	u, err := url.ParseRequestURI(origin)
   919  	if err != nil {
   920  		return err
   921  	}
   922  	oh, op, err := wsGetHostAndPort(u.Scheme == "https", u.Host)
   923  	if err != nil {
   924  		return err
   925  	}
   926  	// If checking same origin, compare with the http's request's Host.
   927  	if checkSame {
   928  		rh, rp, err := wsGetHostAndPort(r.TLS != nil, r.Host)
   929  		if err != nil {
   930  			return err
   931  		}
   932  		if oh != rh || op != rp {
   933  			return errors.New("not same origin")
   934  		}
   935  		// I guess it is possible to have cases where one wants to check
   936  		// same origin, but also that the origin is in the allowed list.
   937  		// So continue with the next check.
   938  	}
   939  	if !listEmpty {
   940  		w.mu.RLock()
   941  		ao := w.allowedOrigins[oh]
   942  		w.mu.RUnlock()
   943  		if ao == nil || u.Scheme != ao.scheme || op != ao.port {
   944  			return errors.New("not in the allowed list")
   945  		}
   946  	}
   947  	return nil
   948  }
   949  
   950  func wsGetHostAndPort(tls bool, hostport string) (string, string, error) {
   951  	host, port, err := net.SplitHostPort(hostport)
   952  	if err != nil {
   953  		// If error is missing port, then use defaults based on the scheme
   954  		if ae, ok := err.(*net.AddrError); ok && strings.Contains(ae.Err, "missing port") {
   955  			err = nil
   956  			host = hostport
   957  			if tls {
   958  				port = "443"
   959  			} else {
   960  				port = "80"
   961  			}
   962  		}
   963  	}
   964  	return strings.ToLower(host), port, err
   965  }
   966  
   967  // Concatenate the key sent by the client with the GUID, then computes the SHA1 hash
   968  // and returns it as a based64 encoded string.
   969  func wsAcceptKey(key string) string {
   970  	h := sha1.New()
   971  	h.Write([]byte(key))
   972  	h.Write(wsGUID)
   973  	return base64.StdEncoding.EncodeToString(h.Sum(nil))
   974  }
   975  
   976  func wsMakeChallengeKey() (string, error) {
   977  	p := make([]byte, 16)
   978  	if _, err := io.ReadFull(crand.Reader, p); err != nil {
   979  		return _EMPTY_, err
   980  	}
   981  	return base64.StdEncoding.EncodeToString(p), nil
   982  }
   983  
   984  // Validate the websocket related options.
   985  func validateWebsocketOptions(o *Options) error {
   986  	wo := &o.Websocket
   987  	// If no port is defined, we don't care about other options
   988  	if wo.Port == 0 {
   989  		return nil
   990  	}
   991  	// Enforce TLS... unless NoTLS is set to true.
   992  	if wo.TLSConfig == nil && !wo.NoTLS {
   993  		return errors.New("websocket requires TLS configuration")
   994  	}
   995  	// Make sure that allowed origins, if specified, can be parsed.
   996  	for _, ao := range wo.AllowedOrigins {
   997  		if _, err := url.Parse(ao); err != nil {
   998  			return fmt.Errorf("unable to parse allowed origin: %v", err)
   999  		}
  1000  	}
  1001  	// If there is a NoAuthUser, we need to have Users defined and
  1002  	// the user to be present.
  1003  	if wo.NoAuthUser != _EMPTY_ {
  1004  		if err := validateNoAuthUser(o, wo.NoAuthUser); err != nil {
  1005  			return err
  1006  		}
  1007  	}
  1008  	// Token/Username not possible if there are users/nkeys
  1009  	if len(o.Users) > 0 || len(o.Nkeys) > 0 {
  1010  		if wo.Username != _EMPTY_ {
  1011  			return fmt.Errorf("websocket authentication username not compatible with presence of users/nkeys")
  1012  		}
  1013  		if wo.Token != _EMPTY_ {
  1014  			return fmt.Errorf("websocket authentication token not compatible with presence of users/nkeys")
  1015  		}
  1016  	}
  1017  	// Using JWT requires Trusted Keys
  1018  	if wo.JWTCookie != _EMPTY_ {
  1019  		if len(o.TrustedOperators) == 0 && len(o.TrustedKeys) == 0 {
  1020  			return fmt.Errorf("trusted operators or trusted keys configuration is required for JWT authentication via cookie %q", wo.JWTCookie)
  1021  		}
  1022  	}
  1023  	if err := validatePinnedCerts(wo.TLSPinnedCerts); err != nil {
  1024  		return fmt.Errorf("websocket: %v", err)
  1025  	}
  1026  
  1027  	// Check for invalid headers here.
  1028  	for key := range wo.Headers {
  1029  		k := strings.ToLower(key)
  1030  		switch k {
  1031  		case "host",
  1032  			"content-length",
  1033  			"connection",
  1034  			"upgrade",
  1035  			"nats-no-masking":
  1036  			return fmt.Errorf("websocket: invalid header %q not allowed", key)
  1037  		}
  1038  
  1039  		if strings.HasPrefix(k, "sec-websocket-") {
  1040  			return fmt.Errorf("websocket: invalid header %q, \"Sec-WebSocket-\" prefix not allowed", key)
  1041  		}
  1042  	}
  1043  
  1044  	return nil
  1045  }
  1046  
  1047  // Creates or updates the existing map
  1048  func (s *Server) wsSetOriginOptions(o *WebsocketOpts) {
  1049  	ws := &s.websocket
  1050  	ws.mu.Lock()
  1051  	defer ws.mu.Unlock()
  1052  	// Copy over the option's same origin boolean
  1053  	ws.sameOrigin = o.SameOrigin
  1054  	// Reset the map. Will help for config reload if/when we support it.
  1055  	ws.allowedOrigins = nil
  1056  	if o.AllowedOrigins == nil {
  1057  		return
  1058  	}
  1059  	for _, ao := range o.AllowedOrigins {
  1060  		// We have previously checked (during options validation) that the urls
  1061  		// are parseable, but if we get an error, report and skip.
  1062  		u, err := url.ParseRequestURI(ao)
  1063  		if err != nil {
  1064  			s.Errorf("error parsing allowed origin: %v", err)
  1065  			continue
  1066  		}
  1067  		h, p, _ := wsGetHostAndPort(u.Scheme == "https", u.Host)
  1068  		if ws.allowedOrigins == nil {
  1069  			ws.allowedOrigins = make(map[string]*allowedOrigin, len(o.AllowedOrigins))
  1070  		}
  1071  		ws.allowedOrigins[h] = &allowedOrigin{scheme: u.Scheme, port: p}
  1072  	}
  1073  }
  1074  
  1075  // Calculate the raw headers for websocket upgrade response.
  1076  func (s *Server) wsSetHeadersOptions(o *WebsocketOpts) {
  1077  	var sb strings.Builder
  1078  	for k, v := range o.Headers {
  1079  		sb.WriteString(k)
  1080  		sb.WriteString(": ")
  1081  		sb.WriteString(v)
  1082  		sb.WriteString(_CRLF_)
  1083  	}
  1084  	ws := &s.websocket
  1085  	ws.mu.Lock()
  1086  	defer ws.mu.Unlock()
  1087  	ws.rawHeaders = sb.String()
  1088  }
  1089  
  1090  // Given the websocket options, we check if any auth configuration
  1091  // has been provided. If so, possibly create users/nkey users and
  1092  // store them in s.websocket.users/nkeys.
  1093  // Also update a boolean that indicates if auth is required for
  1094  // websocket clients.
  1095  // Server lock is held on entry.
  1096  func (s *Server) wsConfigAuth(opts *WebsocketOpts) {
  1097  	ws := &s.websocket
  1098  	// If any of those is specified, we consider that there is an override.
  1099  	ws.authOverride = opts.Username != _EMPTY_ || opts.Token != _EMPTY_ || opts.NoAuthUser != _EMPTY_
  1100  }
  1101  
  1102  func (s *Server) startWebsocketServer() {
  1103  	if s.isShuttingDown() {
  1104  		return
  1105  	}
  1106  
  1107  	sopts := s.getOpts()
  1108  	o := &sopts.Websocket
  1109  
  1110  	s.wsSetOriginOptions(o)
  1111  	s.wsSetHeadersOptions(o)
  1112  
  1113  	var hl net.Listener
  1114  	var proto string
  1115  	var err error
  1116  
  1117  	port := o.Port
  1118  	if port == -1 {
  1119  		port = 0
  1120  	}
  1121  	hp := net.JoinHostPort(o.Host, strconv.Itoa(port))
  1122  
  1123  	// We are enforcing (when validating the options) the use of TLS, but the
  1124  	// code was originally supporting both modes. The reason for TLS only is
  1125  	// that we expect users to send JWTs with bearer tokens and we want to
  1126  	// avoid the possibility of it being "intercepted".
  1127  
  1128  	s.mu.Lock()
  1129  	// Do not check o.NoTLS here. If a TLS configuration is available, use it,
  1130  	// regardless of NoTLS. If we don't have a TLS config, it means that the
  1131  	// user has configured NoTLS because otherwise the server would have failed
  1132  	// to start due to options validation.
  1133  	if o.TLSConfig != nil {
  1134  		proto = wsSchemePrefixTLS
  1135  		config := o.TLSConfig.Clone()
  1136  		config.GetConfigForClient = s.wsGetTLSConfig
  1137  		hl, err = tls.Listen("tcp", hp, config)
  1138  	} else {
  1139  		proto = wsSchemePrefix
  1140  		hl, err = net.Listen("tcp", hp)
  1141  	}
  1142  	s.websocket.listenerErr = err
  1143  	if err != nil {
  1144  		s.mu.Unlock()
  1145  		s.Fatalf("Unable to listen for websocket connections: %v", err)
  1146  		return
  1147  	}
  1148  	if port == 0 {
  1149  		o.Port = hl.Addr().(*net.TCPAddr).Port
  1150  	}
  1151  	s.Noticef("Listening for websocket clients on %s://%s:%d", proto, o.Host, o.Port)
  1152  	if proto == wsSchemePrefix {
  1153  		s.Warnf("Websocket not configured with TLS. DO NOT USE IN PRODUCTION!")
  1154  	}
  1155  
  1156  	s.websocket.tls = proto == "wss"
  1157  	s.websocket.connectURLs, err = s.getConnectURLs(o.Advertise, o.Host, o.Port)
  1158  	if err != nil {
  1159  		s.Fatalf("Unable to get websocket connect URLs: %v", err)
  1160  		hl.Close()
  1161  		s.mu.Unlock()
  1162  		return
  1163  	}
  1164  	hasLeaf := sopts.LeafNode.Port != 0
  1165  	mux := http.NewServeMux()
  1166  	mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
  1167  		res, err := s.wsUpgrade(w, r)
  1168  		if err != nil {
  1169  			s.Errorf(err.Error())
  1170  			return
  1171  		}
  1172  		switch res.kind {
  1173  		case CLIENT:
  1174  			s.createWSClient(res.conn, res.ws)
  1175  		case MQTT:
  1176  			s.createMQTTClient(res.conn, res.ws)
  1177  		case LEAF:
  1178  			if !hasLeaf {
  1179  				s.Errorf("Not configured to accept leaf node connections")
  1180  				// Silently close for now. If we want to send an error back, we would
  1181  				// need to create the leafnode client anyway, so that is handling websocket
  1182  				// frames, then send the error to the remote.
  1183  				res.conn.Close()
  1184  				return
  1185  			}
  1186  			s.createLeafNode(res.conn, nil, nil, res.ws)
  1187  		}
  1188  	})
  1189  	hs := &http.Server{
  1190  		Addr:        hp,
  1191  		Handler:     mux,
  1192  		ReadTimeout: o.HandshakeTimeout,
  1193  		ErrorLog:    log.New(&captureHTTPServerLog{s, "websocket: "}, _EMPTY_, 0),
  1194  	}
  1195  	s.websocket.server = hs
  1196  	s.websocket.listener = hl
  1197  	go func() {
  1198  		if err := hs.Serve(hl); err != http.ErrServerClosed {
  1199  			s.Fatalf("websocket listener error: %v", err)
  1200  		}
  1201  		if s.isLameDuckMode() {
  1202  			// Signal that we are not accepting new clients
  1203  			s.ldmCh <- true
  1204  			// Now wait for the Shutdown...
  1205  			<-s.quitCh
  1206  			return
  1207  		}
  1208  		s.done <- true
  1209  	}()
  1210  	s.mu.Unlock()
  1211  }
  1212  
  1213  // The TLS configuration is passed to the listener when the websocket
  1214  // "server" is setup. That prevents TLS configuration updates on reload
  1215  // from being used. By setting this function in tls.Config.GetConfigForClient
  1216  // we instruct the TLS handshake to ask for the tls configuration to be
  1217  // used for a specific client. We don't care which client, we always use
  1218  // the same TLS configuration.
  1219  func (s *Server) wsGetTLSConfig(_ *tls.ClientHelloInfo) (*tls.Config, error) {
  1220  	opts := s.getOpts()
  1221  	return opts.Websocket.TLSConfig, nil
  1222  }
  1223  
// createWSClient creates a client for an already-upgraded websocket
// connection `conn`, using websocket state `ws`.
//
// This is similar to createClient() but has some modifications
// specific to handle websocket clients.
// The comments have been kept to minimum to reduce code size.
// Check createClient() for more details.
//
// It sends the INFO protocol right away. It then registers the client in
// s.clients, arms the auth timer (if auth is required) and the ping timer,
// and starts the read and write loops. Return values:
//   - c, without registering it, if the server is not running or is in lame
//     duck mode. The connection is closed only if the server is shutting down.
//   - nil if opts.MaxConn is reached (c.maxConnExceeded() is called).
//   - nil if the client is already closed when about to start its loops,
//     presumably because sending INFO failed. closeConnection(WriteError)
//     is called in that case.
//   - c otherwise.
func (s *Server) createWSClient(conn net.Conn, ws *websocket) *client {
	opts := s.getOpts()

	maxPay := int32(opts.MaxPayload)
	maxSubs := int32(opts.MaxSubs)
	// A configured MaxSubs of 0 is stored as -1 (presumably "no limit",
	// as in createClient()).
	if maxSubs == 0 {
		maxSubs = -1
	}
	now := time.Now().UTC()

	c := &client{srv: s, nc: conn, opts: defaultOpts, mpay: maxPay, msubs: maxSubs, start: now, last: now, ws: ws}

	c.registerWithAccount(s.globalAccount())

	var info Info
	var authRequired bool

	s.mu.Lock()
	info = s.copyInfo()
	// Check auth, override if applicable.
	if !info.AuthRequired {
		// Set info.AuthRequired since this is what is sent to the client.
		info.AuthRequired = s.websocket.authOverride
	}
	if s.nonceRequired() {
		var raw [nonceLen]byte
		nonce := raw[:]
		s.generateNonce(nonce)
		info.Nonce = string(nonce)
	}
	c.nonce = []byte(info.Nonce)
	authRequired = info.AuthRequired

	s.totalClients++
	s.mu.Unlock()

	c.mu.Lock()
	if authRequired {
		c.flags.set(expectConnect)
	}
	c.initClient()
	c.Debugf("Client connection created")
	// INFO is written before the client is registered with the server.
	c.sendProtoNow(c.generateClientInfoJSON(info))
	c.mu.Unlock()

	s.mu.Lock()
	if !s.isRunning() || s.ldm {
		if s.isShuttingDown() {
			conn.Close()
		}
		s.mu.Unlock()
		return c
	}

	if opts.MaxConn > 0 && len(s.clients) >= opts.MaxConn {
		s.mu.Unlock()
		c.maxConnExceeded()
		return nil
	}
	s.clients[c.cid] = c
	s.mu.Unlock()

	c.mu.Lock()
	// Websocket clients do TLS in the websocket http server.
	// So no TLS initiation here...
	if _, ok := conn.(*tls.Conn); ok {
		c.flags.set(handshakeComplete)
	}

	// The client may have been closed meanwhile (e.g. by a failed write of
	// INFO above). In that case, do not start the loops.
	if c.isClosed() {
		c.mu.Unlock()
		c.closeConnection(WriteError)
		return nil
	}

	if authRequired {
		timeout := opts.AuthTimeout
		// Possibly override with Websocket specific value.
		if opts.Websocket.AuthTimeout != 0 {
			timeout = opts.Websocket.AuthTimeout
		}
		c.setAuthTimer(secondsToDuration(timeout))
	}

	c.setPingTimer()

	s.startGoRoutine(func() { c.readLoop(nil) })
	s.startGoRoutine(func() { c.writeLoop() })

	c.mu.Unlock()

	return c
}
  1321  
// wsCollapsePtoNB turns the pending outbound buffers (c.out.nb) into
// websocket frames. It returns the buffers to write and c.ws.fs, the running
// total of framed bytes.
//
// The returned list starts with any already-framed buffers held in
// c.ws.frames (partials or control frames such as pings/pongs). It ends with
// the pending close message, if any. In between, the payload is framed in one
// of three ways:
//   - Compressed, when c.ws.compress is set and the total payload exceeds
//     wsCompressThreshold. All buffers are deflated together (BestSpeed) and
//     sent as one message. For browsers the message is split into frames of
//     at most wsFrameSizeForBrowsers bytes, unless c.ws.nocompfrag is set.
//   - Uncompressed with a size limit (browsers). Buffers are copied into
//     frames carrying at most wsFrameSizeForBrowsers bytes each.
//   - Uncompressed without a limit. A single frame header is followed by
//     the original buffers.
//
// Payloads are masked in place when c.ws.maskwrite is set. c.out.pb is
// adjusted for the added header bytes and, when compressing, for the
// compressed-minus-uncompressed size difference. c.ws.frames and
// c.ws.closeMsg are cleared. If flushing the compressor fails, the
// connection is marked closed with WriteError and (nil, 0) is returned.
// NOTE(review): presumably called with c.mu held — confirm with callers.
func (c *client) wsCollapsePtoNB() (net.Buffers, int64) {
	nb := c.out.nb
	var mfs int
	var usz int
	// mfs is the maximum frame payload size; 0 means no limit.
	if c.ws.browser {
		mfs = wsFrameSizeForBrowsers
	}
	mask := c.ws.maskwrite
	// Start with possible already framed buffers (that we could have
	// got from partials or control messages such as ws pings or pongs).
	bufs := c.ws.frames
	compress := c.ws.compress
	if compress && len(nb) > 0 {
		// First, make sure we don't compress for very small cumulative buffers.
		for _, b := range nb {
			usz += len(b)
		}
		if usz <= wsCompressThreshold {
			compress = false
		}
	}
	if compress && len(nb) > 0 {
		// Overwrite mfs if this connection does not support fragmented compressed frames.
		if mfs > 0 && c.ws.nocompfrag {
			mfs = 0
		}
		buf := bytes.NewBuffer(nbPoolGet(usz))
		cp := c.ws.compressor
		if cp == nil {
			c.ws.compressor, _ = flate.NewWriter(buf, flate.BestSpeed)
			cp = c.ws.compressor
		} else {
			cp.Reset(buf)
		}
		var csz int
		for _, b := range nb {
			cp.Write(b)
			nbPoolPut(b) // No longer needed as contents written to compressor.
		}
		if err := cp.Flush(); err != nil {
			c.Errorf("Error during compression: %v", err)
			c.markConnAsClosed(WriteError)
			return nil, 0
		}
		b := buf.Bytes()
		// Drop the last 4 bytes written by Flush. These are presumably the
		// 0x00 0x00 0xff 0xff sync tail that RFC 7692 section 7.2.1 requires
		// the sender to remove.
		p := b[:len(b)-4]
		if mfs > 0 && len(p) > mfs {
			for first, final := true, false; len(p) > 0; first = false {
				lp := len(p)
				if lp > mfs {
					lp = mfs
				} else {
					final = true
				}
				// Only the first frame should be marked as compressed, so pass
				// `first` for the compressed boolean.
				fh := nbPoolGet(wsMaxFrameHeaderSize)[:wsMaxFrameHeaderSize]
				n, key := wsFillFrameHeader(fh, mask, first, final, first, wsBinaryMessage, lp)
				if mask {
					wsMaskBuf(key, p[:lp])
				}
				bufs = append(bufs, fh[:n], p[:lp])
				csz += n + lp
				p = p[lp:]
			}
		} else {
			ol := len(p)
			h, key := wsCreateFrameHeader(mask, true, wsBinaryMessage, ol)
			if mask {
				wsMaskBuf(key, p)
			}
			// NOTE(review): when ol == 0 nothing is appended, yet csz below
			// still counts len(h). Confirm that this cannot happen or that
			// the accounting is intended.
			if ol > 0 {
				bufs = append(bufs, h, p)
			}
			csz = len(h) + ol
		}
		// Make sure that the compressor no longer holds a reference to
		// the bytes.Buffer, so that the underlying memory gets cleaned
		// up after flushOutbound/flushAndClose. For this to be safe, we
		// always cp.Reset(...) before reusing the compressor again.
		cp.Reset(nil)
		// Add to pb the compressed data size (including headers), but
		// remove the original uncompressed data size that was added
		// during the queueing.
		c.out.pb += int64(csz) - int64(usz)
		c.ws.fs += int64(csz)
	} else if len(nb) > 0 {
		var total int
		if mfs > 0 {
			// We are limiting the frame size.
			// startFrame reserves a header buffer at the end of bufs and
			// returns its index; the header is filled in by endFrame.
			startFrame := func() int {
				bufs = append(bufs, nbPoolGet(wsMaxFrameHeaderSize))
				return len(bufs) - 1
			}
			// endFrame writes the header at bufs[idx] for a frame of `size`
			// payload bytes and accounts for it. If masking is enabled, it
			// masks every buffer after idx, i.e. this frame's payload.
			endFrame := func(idx, size int) {
				bufs[idx] = bufs[idx][:wsMaxFrameHeaderSize]
				n, key := wsFillFrameHeader(bufs[idx], mask, wsFirstFrame, wsFinalFrame, wsUncompressedFrame, wsBinaryMessage, size)
				bufs[idx] = bufs[idx][:n]
				c.out.pb += int64(n)
				c.ws.fs += int64(n + size)
				if mask {
					wsMaskBufs(key, bufs[idx+1:])
				}
			}

			fhIdx := startFrame()
			for i := 0; i < len(nb); i++ {
				b := nb[i]
				// The whole buffer fits in the current frame: copy it there.
				// A copy is made, presumably so that masking does not alter
				// the original buffer before it goes back to the pool.
				if total+len(b) <= mfs {
					buf := nbPoolGet(len(b))
					bufs = append(bufs, append(buf, b...))
					total += len(b)
					nbPoolPut(nb[i])
					continue
				}
				// Otherwise close the current frame (if it has data) and
				// split the buffer into chunks of at most mfs bytes, each
				// starting a new frame.
				for len(b) > 0 {
					endStart := total != 0
					if endStart {
						endFrame(fhIdx, total)
					}
					total = len(b)
					if total >= mfs {
						total = mfs
					}
					if endStart {
						fhIdx = startFrame()
					}
					buf := nbPoolGet(total)
					bufs = append(bufs, append(buf, b[:total]...))
					b = b[total:]
				}
				nbPoolPut(nb[i]) // No longer needed as copied into smaller frames.
			}
			if total > 0 {
				endFrame(fhIdx, total)
			}
		} else {
			// If there is no limit on the frame size, create a single frame for
			// all pending buffers.
			for _, b := range nb {
				total += len(b)
			}
			wsfh, key := wsCreateFrameHeader(mask, false, wsBinaryMessage, total)
			c.out.pb += int64(len(wsfh))
			bufs = append(bufs, wsfh)
			idx := len(bufs)
			bufs = append(bufs, nb...)
			if mask {
				wsMaskBufs(key, bufs[idx:])
			}
			c.ws.fs += int64(len(wsfh) + total)
		}
	}
	// A pending close message always goes last.
	if len(c.ws.closeMsg) > 0 {
		bufs = append(bufs, c.ws.closeMsg)
		c.ws.fs += int64(len(c.ws.closeMsg))
		c.ws.closeMsg = nil
	}
	c.ws.frames = nil
	return bufs, c.ws.fs
}
  1483  
  1484  func isWSURL(u *url.URL) bool {
  1485  	return strings.HasPrefix(strings.ToLower(u.Scheme), wsSchemePrefix)
  1486  }
  1487  
  1488  func isWSSURL(u *url.URL) bool {
  1489  	return strings.HasPrefix(strings.ToLower(u.Scheme), wsSchemePrefixTLS)
  1490  }