github.com/nats-io/nats-server/v2@v2.11.0-preview.2/server/websocket_test.go (about)

     1  // Copyright 2020-2024 The NATS Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package server
    15  
    16  import (
    17  	"bufio"
    18  	"bytes"
    19  	"crypto/tls"
    20  	"encoding/base64"
    21  	"encoding/binary"
    22  	"encoding/json"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"math/rand"
    27  	"net"
    28  	"net/http"
    29  	"net/url"
    30  	"reflect"
    31  	"strings"
    32  	"sync"
    33  	"sync/atomic"
    34  	"testing"
    35  	"time"
    36  
    37  	"github.com/nats-io/jwt/v2"
    38  	"github.com/nats-io/nats.go"
    39  	"github.com/nats-io/nkeys"
    40  
    41  	"github.com/klauspost/compress/flate"
    42  )
    43  
    44  type testReader struct {
    45  	buf []byte
    46  	pos int
    47  	max int
    48  	err error
    49  }
    50  
    51  func (tr *testReader) Read(p []byte) (int, error) {
    52  	if tr.err != nil {
    53  		return 0, tr.err
    54  	}
    55  	n := len(tr.buf) - tr.pos
    56  	if n == 0 {
    57  		return 0, nil
    58  	}
    59  	if n > len(p) {
    60  		n = len(p)
    61  	}
    62  	if tr.max > 0 && n > tr.max {
    63  		n = tr.max
    64  	}
    65  	copy(p, tr.buf[tr.pos:tr.pos+n])
    66  	tr.pos += n
    67  	return n, nil
    68  }
    69  
    70  func TestWSGet(t *testing.T) {
    71  	rb := []byte("012345")
    72  
    73  	tr := &testReader{buf: []byte("6789")}
    74  
    75  	for _, test := range []struct {
    76  		name   string
    77  		pos    int
    78  		needed int
    79  		newpos int
    80  		trmax  int
    81  		result string
    82  		reterr bool
    83  	}{
    84  		{"fromrb1", 0, 3, 3, 4, "012", false},    // Partial from read buffer
    85  		{"fromrb2", 3, 2, 5, 4, "34", false},     // Partial from read buffer
    86  		{"fromrb3", 5, 1, 6, 4, "5", false},      // Partial from read buffer
    87  		{"fromtr1", 4, 4, 6, 4, "4567", false},   // Partial from read buffer + some of ioReader
    88  		{"fromtr2", 4, 6, 6, 4, "456789", false}, // Partial from read buffer + all of ioReader
    89  		{"fromtr3", 4, 6, 6, 2, "456789", false}, // Partial from read buffer + all of ioReader with several reads
    90  		{"fromtr4", 4, 6, 6, 2, "", true},        // ioReader returns error
    91  	} {
    92  		t.Run(test.name, func(t *testing.T) {
    93  			tr.pos = 0
    94  			tr.max = test.trmax
    95  			if test.reterr {
    96  				tr.err = fmt.Errorf("on purpose")
    97  			}
    98  			res, np, err := wsGet(tr, rb, test.pos, test.needed)
    99  			if test.reterr {
   100  				if err == nil {
   101  					t.Fatalf("Expected error, got none")
   102  				}
   103  				if err.Error() != "on purpose" {
   104  					t.Fatalf("Unexpected error: %v", err)
   105  				}
   106  				if np != 0 || res != nil {
   107  					t.Fatalf("Unexpected returned values: res=%v n=%v", res, np)
   108  				}
   109  				return
   110  			}
   111  			if err != nil {
   112  				t.Fatalf("Error on get: %v", err)
   113  			}
   114  			if np != test.newpos {
   115  				t.Fatalf("Expected pos=%v, got %v", test.newpos, np)
   116  			}
   117  			if string(res) != test.result {
   118  				t.Fatalf("Invalid returned content: %s", res)
   119  			}
   120  		})
   121  	}
   122  }
   123  
   124  func TestWSIsControlFrame(t *testing.T) {
   125  	for _, test := range []struct {
   126  		name      string
   127  		code      wsOpCode
   128  		isControl bool
   129  	}{
   130  		{"binary", wsBinaryMessage, false},
   131  		{"text", wsTextMessage, false},
   132  		{"ping", wsPingMessage, true},
   133  		{"pong", wsPongMessage, true},
   134  		{"close", wsCloseMessage, true},
   135  	} {
   136  		t.Run(test.name, func(t *testing.T) {
   137  			if res := wsIsControlFrame(test.code); res != test.isControl {
   138  				t.Fatalf("Expected %q isControl to be %v, got %v", test.name, test.isControl, res)
   139  			}
   140  		})
   141  	}
   142  }
   143  
   144  func testWSSimpleMask(key, buf []byte) {
   145  	for i := 0; i < len(buf); i++ {
   146  		buf[i] ^= key[i&3]
   147  	}
   148  }
   149  
   150  func TestWSUnmask(t *testing.T) {
   151  	key := []byte{1, 2, 3, 4}
   152  	orgBuf := []byte("this is a clear text")
   153  
   154  	mask := func() []byte {
   155  		t.Helper()
   156  		buf := append([]byte(nil), orgBuf...)
   157  		testWSSimpleMask(key, buf)
   158  		// First ensure that the content is masked.
   159  		if bytes.Equal(buf, orgBuf) {
   160  			t.Fatalf("Masking did not do anything: %q", buf)
   161  		}
   162  		return buf
   163  	}
   164  
   165  	ri := &wsReadInfo{mask: true}
   166  	ri.init()
   167  	copy(ri.mkey[:], key)
   168  
   169  	buf := mask()
   170  	// Unmask in one call
   171  	ri.unmask(buf)
   172  	if !bytes.Equal(buf, orgBuf) {
   173  		t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf)
   174  	}
   175  
   176  	// Unmask in multiple calls
   177  	buf = mask()
   178  	ri.mkpos = 0
   179  	ri.unmask(buf[:3])
   180  	ri.unmask(buf[3:11])
   181  	ri.unmask(buf[11:])
   182  	if !bytes.Equal(buf, orgBuf) {
   183  		t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf)
   184  	}
   185  }
   186  
   187  func TestWSCreateCloseMessage(t *testing.T) {
   188  	for _, test := range []struct {
   189  		name      string
   190  		status    int
   191  		psize     int
   192  		truncated bool
   193  	}{
   194  		{"fits", wsCloseStatusInternalSrvError, 10, false},
   195  		{"truncated", wsCloseStatusProtocolError, wsMaxControlPayloadSize + 10, true},
   196  	} {
   197  		t.Run(test.name, func(t *testing.T) {
   198  			payload := make([]byte, test.psize)
   199  			for i := 0; i < len(payload); i++ {
   200  				payload[i] = byte('A' + (i % 26))
   201  			}
   202  			res := wsCreateCloseMessage(test.status, string(payload))
   203  			if status := binary.BigEndian.Uint16(res[:2]); int(status) != test.status {
   204  				t.Fatalf("Expected status to be %v, got %v", test.status, status)
   205  			}
   206  			psize := len(res) - 2
   207  			if !test.truncated {
   208  				if int(psize) != test.psize {
   209  					t.Fatalf("Expected size to be %v, got %v", test.psize, psize)
   210  				}
   211  				if !bytes.Equal(res[2:], payload) {
   212  					t.Fatalf("Unexpected result: %q", res[2:])
   213  				}
   214  				return
   215  			}
   216  			// Since the payload of a close message contains a 2 byte status, the
   217  			// actual max text size will be wsMaxControlPayloadSize-2
   218  			if int(psize) != wsMaxControlPayloadSize-2 {
   219  				t.Fatalf("Expected size to be capped to %v, got %v", wsMaxControlPayloadSize-2, psize)
   220  			}
   221  			if string(res[len(res)-3:]) != "..." {
   222  				t.Fatalf("Expected res to have `...` at the end, got %q", res[4:])
   223  			}
   224  		})
   225  	}
   226  }
   227  
   228  func TestWSCreateFrameHeader(t *testing.T) {
   229  	for _, test := range []struct {
   230  		name       string
   231  		frameType  wsOpCode
   232  		compressed bool
   233  		len        int
   234  	}{
   235  		{"uncompressed 10", wsBinaryMessage, false, 10},
   236  		{"uncompressed 600", wsTextMessage, false, 600},
   237  		{"uncompressed 100000", wsTextMessage, false, 100000},
   238  		{"compressed 10", wsBinaryMessage, true, 10},
   239  		{"compressed 600", wsBinaryMessage, true, 600},
   240  		{"compressed 100000", wsTextMessage, true, 100000},
   241  	} {
   242  		t.Run(test.name, func(t *testing.T) {
   243  			res, _ := wsCreateFrameHeader(false, test.compressed, test.frameType, test.len)
   244  			// The server is always sending the message has a single frame,
   245  			// so the "final" bit should be set.
   246  			expected := byte(test.frameType) | wsFinalBit
   247  			if test.compressed {
   248  				expected |= wsRsv1Bit
   249  			}
   250  			if b := res[0]; b != expected {
   251  				t.Fatalf("Expected first byte to be %v, got %v", expected, b)
   252  			}
   253  			switch {
   254  			case test.len <= 125:
   255  				if len(res) != 2 {
   256  					t.Fatalf("Frame len should be 2, got %v", len(res))
   257  				}
   258  				if res[1] != byte(test.len) {
   259  					t.Fatalf("Expected len to be in second byte and be %v, got %v", test.len, res[1])
   260  				}
   261  			case test.len < 65536:
   262  				// 1+1+2
   263  				if len(res) != 4 {
   264  					t.Fatalf("Frame len should be 4, got %v", len(res))
   265  				}
   266  				if res[1] != 126 {
   267  					t.Fatalf("Second byte value should be 126, got %v", res[1])
   268  				}
   269  				if rl := binary.BigEndian.Uint16(res[2:]); int(rl) != test.len {
   270  					t.Fatalf("Expected len to be %v, got %v", test.len, rl)
   271  				}
   272  			default:
   273  				// 1+1+8
   274  				if len(res) != 10 {
   275  					t.Fatalf("Frame len should be 10, got %v", len(res))
   276  				}
   277  				if res[1] != 127 {
   278  					t.Fatalf("Second byte value should be 127, got %v", res[1])
   279  				}
   280  				if rl := binary.BigEndian.Uint64(res[2:]); int(rl) != test.len {
   281  					t.Fatalf("Expected len to be %v, got %v", test.len, rl)
   282  				}
   283  			}
   284  		})
   285  	}
   286  }
   287  
   288  func testWSCreateClientMsg(frameType wsOpCode, frameNum int, final, compressed bool, payload []byte) []byte {
   289  	if compressed {
   290  		buf := &bytes.Buffer{}
   291  		compressor, _ := flate.NewWriter(buf, 1)
   292  		compressor.Write(payload)
   293  		compressor.Flush()
   294  		payload = buf.Bytes()
   295  		// The last 4 bytes are dropped
   296  		payload = payload[:len(payload)-4]
   297  	}
   298  	frame := make([]byte, 14+len(payload))
   299  	if frameNum == 1 {
   300  		frame[0] = byte(frameType)
   301  	}
   302  	if final {
   303  		frame[0] |= wsFinalBit
   304  	}
   305  	if compressed {
   306  		frame[0] |= wsRsv1Bit
   307  	}
   308  	pos := 1
   309  	lenPayload := len(payload)
   310  	switch {
   311  	case lenPayload <= 125:
   312  		frame[pos] = byte(lenPayload) | wsMaskBit
   313  		pos++
   314  	case lenPayload < 65536:
   315  		frame[pos] = 126 | wsMaskBit
   316  		binary.BigEndian.PutUint16(frame[2:], uint16(lenPayload))
   317  		pos += 3
   318  	default:
   319  		frame[1] = 127 | wsMaskBit
   320  		binary.BigEndian.PutUint64(frame[2:], uint64(lenPayload))
   321  		pos += 9
   322  	}
   323  	key := []byte{1, 2, 3, 4}
   324  	copy(frame[pos:], key)
   325  	pos += 4
   326  	copy(frame[pos:], payload)
   327  	testWSSimpleMask(key, frame[pos:])
   328  	pos += lenPayload
   329  	return frame[:pos]
   330  }
   331  
   332  func testWSSetupForRead() (*client, *wsReadInfo, *testReader) {
   333  	ri := &wsReadInfo{mask: true}
   334  	ri.init()
   335  	tr := &testReader{}
   336  	opts := DefaultOptions()
   337  	opts.MaxPending = MAX_PENDING_SIZE
   338  	s := &Server{opts: opts}
   339  	c := &client{srv: s, ws: &websocket{}}
   340  	c.initClient()
   341  	return c, ri, tr
   342  }
   343  
// TestWSReadUncompressedFrames verifies that wsRead extracts the payloads of
// uncompressed frames in two situations:
//   - the read buffer holds complete frames;
//   - the frames are split across successive wsRead calls that share the same
//     wsReadInfo, in which case each call returns the partial payload bytes it
//     received.
func TestWSReadUncompressedFrames(t *testing.T) {
	c, ri, tr := testWSSetupForRead()
	// Create 2 WS messages
	pl1 := []byte("first message")
	wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl1)
	pl2 := []byte("second message")
	wsmsg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl2)
	// Add both in single buffer
	orgrb := append([]byte(nil), wsmsg1...)
	orgrb = append(orgrb, wsmsg2...)

	// wsRead unmasks the read buffer it is given, so pass a copy and keep
	// orgrb masked for the second scenario below.
	rb := append([]byte(nil), orgrb...)
	bufs, err := c.wsRead(ri, tr, rb)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if n := len(bufs); n != 2 {
		t.Fatalf("Expected 2 buffers, got %v", n)
	}
	if !bytes.Equal(bufs[0], pl1) {
		t.Fatalf("Unexpected content for buffer 1: %s", bufs[0])
	}
	if !bytes.Equal(bufs[1], pl2) {
		t.Fatalf("Unexpected content for buffer 2: %s", bufs[1])
	}

	// Now reset and try with the read buffer not containing full ws frame
	c, ri, tr = testWSSetupForRead()
	rb = append([]byte(nil), orgrb...)
	// Frame is 1+1+4+'first message'. So say we pass with rb of 11 bytes,
	// then we should get "first"
	bufs, err = c.wsRead(ri, tr, rb[:11])
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if n := len(bufs); n != 1 {
		t.Fatalf("Unexpected buffer returned: %v", n)
	}
	if string(bufs[0]) != "first" {
		t.Fatalf("Unexpected content: %q", bufs[0])
	}
	// Call again with more data..
	// The first frame is 19 bytes long (6 header bytes + 13 payload bytes), so
	// rb[11:32] holds:
	//   - the 8 remaining payload bytes of frame 1 (" message");
	//   - the 6-byte header of frame 2;
	//   - the first 7 payload bytes of frame 2 ("second ").
	bufs, err = c.wsRead(ri, tr, rb[11:32])
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if n := len(bufs); n != 2 {
		t.Fatalf("Unexpected buffer returned: %v", n)
	}
	if string(bufs[0]) != " message" {
		t.Fatalf("Unexpected content: %q", bufs[0])
	}
	if string(bufs[1]) != "second " {
		t.Fatalf("Unexpected content: %q", bufs[1])
	}
	// Call with the rest
	// rb[32:] is the remaining "message" part of the second payload.
	bufs, err = c.wsRead(ri, tr, rb[32:])
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if n := len(bufs); n != 1 {
		t.Fatalf("Unexpected buffer returned: %v", n)
	}
	if string(bufs[0]) != "message" {
		t.Fatalf("Unexpected content: %q", bufs[0])
	}
}
   411  
// TestWSReadCompressedFrames verifies decompression in wsRead in three
// scenarios:
//   - a compressed message split across two wsRead calls yields nothing until
//     it is complete;
//   - ten compressed messages in one buffer each decompress correctly, which
//     exercises reuse of pooled decompressors;
//   - a compressed message split into two fragments is reassembled before it
//     is decompressed.
func TestWSReadCompressedFrames(t *testing.T) {
	c, ri, tr := testWSSetupForRead()
	uncompressed := []byte("this is the uncompress data")
	wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed)
	rb := append([]byte(nil), wsmsg1...)
	// Call with some but not all of the payload
	// (the header is 1+1+4 bytes, so rb[:10] carries only 4 compressed bytes).
	bufs, err := c.wsRead(ri, tr, rb[:10])
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if n := len(bufs); n != 0 {
		t.Fatalf("Unexpected buffer returned: %v", n)
	}
	// Call with the rest, only then should we get the uncompressed data.
	bufs, err = c.wsRead(ri, tr, rb[10:])
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if n := len(bufs); n != 1 {
		t.Fatalf("Unexpected buffer returned: %v", n)
	}
	if !bytes.Equal(bufs[0], uncompressed) {
		t.Fatalf("Unexpected content: %s", bufs[0])
	}
	// Stress the fact that we use a pool and want to make sure
	// that if we get a decompressor from the pool, it is properly reset
	// with the buffer to decompress.
	// Since we unmask the read buffer, reset it now and fill it
	// with 10 compressed frames.
	rb = nil
	for i := 0; i < 10; i++ {
		rb = append(rb, wsmsg1...)
	}
	bufs, err = c.wsRead(ri, tr, rb)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if n := len(bufs); n != 10 {
		t.Fatalf("Unexpected buffer returned: %v", n)
	}

	// Compress a message and send it in several frames.
	buf := &bytes.Buffer{}
	compressor, _ := flate.NewWriter(buf, 1)
	compressor.Write(uncompressed)
	compressor.Flush()
	compressed := buf.Bytes()
	// The last 4 bytes are dropped
	compressed = compressed[:len(compressed)-4]
	ncomp := 10
	// Only the first fragment carries RSV1. Both fragments are built with
	// compressed=false (the data is already compressed), and RSV1 is set on
	// the first one by hand.
	frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, compressed[:ncomp])
	frag1[0] |= wsRsv1Bit
	frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, compressed[ncomp:])
	rb = append([]byte(nil), frag1...)
	rb = append(rb, frag2...)
	bufs, err = c.wsRead(ri, tr, rb)
	if err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if n := len(bufs); n != 1 {
		t.Fatalf("Unexpected buffer returned: %v", n)
	}
	if !bytes.Equal(bufs[0], uncompressed) {
		t.Fatalf("Unexpected content: %s", bufs[0])
	}
}
   478  
   479  func TestWSReadCompressedFrameCorrupted(t *testing.T) {
   480  	c, ri, tr := testWSSetupForRead()
   481  	uncompressed := []byte("this is the uncompress data")
   482  	wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed)
   483  	copy(wsmsg1[10:], []byte{1, 2, 3, 4})
   484  	rb := append([]byte(nil), wsmsg1...)
   485  	bufs, err := c.wsRead(ri, tr, rb)
   486  	if err == nil || !strings.Contains(err.Error(), "corrupt") {
   487  		t.Fatalf("Expected error about corrupted data, got %v", err)
   488  	}
   489  	if n := len(bufs); n != 0 {
   490  		t.Fatalf("Expected no buffer, got %v", n)
   491  	}
   492  }
   493  
   494  func TestWSReadVariousFrameSizes(t *testing.T) {
   495  	for _, test := range []struct {
   496  		name string
   497  		size int
   498  	}{
   499  		{"tiny", 100},
   500  		{"medium", 1000},
   501  		{"large", 70000},
   502  	} {
   503  		t.Run(test.name, func(t *testing.T) {
   504  			c, ri, tr := testWSSetupForRead()
   505  			uncompressed := make([]byte, test.size)
   506  			for i := 0; i < len(uncompressed); i++ {
   507  				uncompressed[i] = 'A' + byte(i%26)
   508  			}
   509  			wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, uncompressed)
   510  			rb := append([]byte(nil), wsmsg1...)
   511  			bufs, err := c.wsRead(ri, tr, rb)
   512  			if err != nil {
   513  				t.Fatalf("Unexpected error: %v", err)
   514  			}
   515  			if n := len(bufs); n != 1 {
   516  				t.Fatalf("Unexpected buffer returned: %v", n)
   517  			}
   518  			if !bytes.Equal(bufs[0], uncompressed) {
   519  				t.Fatalf("Unexpected content: %s", bufs[0])
   520  			}
   521  		})
   522  	}
   523  }
   524  
   525  func TestWSReadFragmentedFrames(t *testing.T) {
   526  	c, ri, tr := testWSSetupForRead()
   527  	payloads := []string{"first", "second", "third"}
   528  	var rb []byte
   529  	for i := 0; i < len(payloads); i++ {
   530  		final := i == len(payloads)-1
   531  		frag := testWSCreateClientMsg(wsBinaryMessage, i+1, final, false, []byte(payloads[i]))
   532  		rb = append(rb, frag...)
   533  	}
   534  	bufs, err := c.wsRead(ri, tr, rb)
   535  	if err != nil {
   536  		t.Fatalf("Unexpected error: %v", err)
   537  	}
   538  	if n := len(bufs); n != 3 {
   539  		t.Fatalf("Unexpected buffer returned: %v", n)
   540  	}
   541  	for i, expected := range payloads {
   542  		if string(bufs[i]) != expected {
   543  			t.Fatalf("Unexpected content for buf=%v: %s", i, bufs[i])
   544  		}
   545  	}
   546  }
   547  
   548  func TestWSReadPartialFrameHeaderAtEndOfReadBuffer(t *testing.T) {
   549  	c, ri, tr := testWSSetupForRead()
   550  	msg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg1"))
   551  	msg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg2"))
   552  	rb := append([]byte(nil), msg1...)
   553  	rb = append(rb, msg2...)
   554  	// We will pass the first frame + the first byte of the next frame.
   555  	rbl := rb[:len(msg1)+1]
   556  	// Make the io reader return the rest of the frame
   557  	tr.buf = rb[len(msg1)+1:]
   558  	bufs, err := c.wsRead(ri, tr, rbl)
   559  	if err != nil {
   560  		t.Fatalf("Unexpected error: %v", err)
   561  	}
   562  	if n := len(bufs); n != 1 {
   563  		t.Fatalf("Unexpected buffer returned: %v", n)
   564  	}
   565  	// We should not have asked to the io reader more than what is needed for reading
   566  	// the frame header. Since we had already the first byte in the read buffer,
   567  	// tr.pos should be 1(size)+4(key)=5
   568  	if tr.pos != 5 {
   569  		t.Fatalf("Expected reader pos to be 5, got %v", tr.pos)
   570  	}
   571  }
   572  
   573  func TestWSReadPingFrame(t *testing.T) {
   574  	for _, test := range []struct {
   575  		name    string
   576  		payload []byte
   577  	}{
   578  		{"without payload", nil},
   579  		{"with payload", []byte("optional payload")},
   580  	} {
   581  		t.Run(test.name, func(t *testing.T) {
   582  			c, ri, tr := testWSSetupForRead()
   583  			ping := testWSCreateClientMsg(wsPingMessage, 1, true, false, test.payload)
   584  			rb := append([]byte(nil), ping...)
   585  			bufs, err := c.wsRead(ri, tr, rb)
   586  			if err != nil {
   587  				t.Fatalf("Unexpected error: %v", err)
   588  			}
   589  			if n := len(bufs); n != 0 {
   590  				t.Fatalf("Unexpected buffer returned: %v", n)
   591  			}
   592  			// A PONG should have been queued with the payload of the ping
   593  			c.mu.Lock()
   594  			nb, _ := c.collapsePtoNB()
   595  			c.mu.Unlock()
   596  			if n := len(nb); n == 0 {
   597  				t.Fatalf("Expected buffers, got %v", n)
   598  			}
   599  			if expected := 2 + len(test.payload); expected != len(nb[0]) {
   600  				t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
   601  			}
   602  			b := nb[0][0]
   603  			if b&wsFinalBit == 0 {
   604  				t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
   605  			}
   606  			if b&byte(wsPongMessage) == 0 {
   607  				t.Fatalf("Should have been a PONG, it wasn't: %v", b)
   608  			}
   609  			if len(test.payload) > 0 {
   610  				if !bytes.Equal(nb[0][2:], test.payload) {
   611  					t.Fatalf("Unexpected content: %s", nb[0][2:])
   612  				}
   613  			}
   614  		})
   615  	}
   616  }
   617  
   618  func TestWSReadPongFrame(t *testing.T) {
   619  	for _, test := range []struct {
   620  		name    string
   621  		payload []byte
   622  	}{
   623  		{"without payload", nil},
   624  		{"with payload", []byte("optional payload")},
   625  	} {
   626  		t.Run(test.name, func(t *testing.T) {
   627  			c, ri, tr := testWSSetupForRead()
   628  			pong := testWSCreateClientMsg(wsPongMessage, 1, true, false, test.payload)
   629  			rb := append([]byte(nil), pong...)
   630  			bufs, err := c.wsRead(ri, tr, rb)
   631  			if err != nil {
   632  				t.Fatalf("Unexpected error: %v", err)
   633  			}
   634  			if n := len(bufs); n != 0 {
   635  				t.Fatalf("Unexpected buffer returned: %v", n)
   636  			}
   637  			// Nothing should be sent...
   638  			c.mu.Lock()
   639  			nb, _ := c.collapsePtoNB()
   640  			c.mu.Unlock()
   641  			if n := len(nb); n != 0 {
   642  				t.Fatalf("Expected no buffer, got %v", n)
   643  			}
   644  		})
   645  	}
   646  }
   647  
   648  func TestWSReadCloseFrame(t *testing.T) {
   649  	for _, test := range []struct {
   650  		name    string
   651  		payload []byte
   652  	}{
   653  		{"without payload", nil},
   654  		{"with payload", []byte("optional payload")},
   655  	} {
   656  		t.Run(test.name, func(t *testing.T) {
   657  			c, ri, tr := testWSSetupForRead()
   658  			// a close message has a status in 2 bytes + optional payload
   659  			payload := make([]byte, 2+len(test.payload))
   660  			binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure)
   661  			if len(test.payload) > 0 {
   662  				copy(payload[2:], test.payload)
   663  			}
   664  			close := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload)
   665  			// Have a normal frame prior to close to make sure that wsRead returns
   666  			// the normal frame along with io.EOF to indicate that wsCloseMessage was received.
   667  			msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg"))
   668  			rb := append([]byte(nil), msg...)
   669  			rb = append(rb, close...)
   670  			bufs, err := c.wsRead(ri, tr, rb)
   671  			// It is expected that wsRead returns io.EOF on processing a close.
   672  			if err != io.EOF {
   673  				t.Fatalf("Unexpected error: %v", err)
   674  			}
   675  			if n := len(bufs); n != 1 {
   676  				t.Fatalf("Unexpected buffer returned: %v", n)
   677  			}
   678  			if string(bufs[0]) != "msg" {
   679  				t.Fatalf("Unexpected content: %s", bufs[0])
   680  			}
   681  			// A CLOSE should have been queued with the payload of the original close message.
   682  			c.mu.Lock()
   683  			nb, _ := c.collapsePtoNB()
   684  			c.mu.Unlock()
   685  			if n := len(nb); n == 0 {
   686  				t.Fatalf("Expected buffers, got %v", n)
   687  			}
   688  			if expected := 2 + 2 + len(test.payload); expected != len(nb[0]) {
   689  				t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
   690  			}
   691  			b := nb[0][0]
   692  			if b&wsFinalBit == 0 {
   693  				t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
   694  			}
   695  			if b&byte(wsCloseMessage) == 0 {
   696  				t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
   697  			}
   698  			if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure {
   699  				t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status)
   700  			}
   701  			if len(test.payload) > 0 {
   702  				if !bytes.Equal(nb[0][4:], test.payload) {
   703  					t.Fatalf("Unexpected content: %s", nb[0][4:])
   704  				}
   705  			}
   706  		})
   707  	}
   708  }
   709  
   710  func TestWSReadControlFrameBetweebFragmentedFrames(t *testing.T) {
   711  	c, ri, tr := testWSSetupForRead()
   712  	frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("first"))
   713  	frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, []byte("second"))
   714  	ctrl := testWSCreateClientMsg(wsPongMessage, 1, true, false, nil)
   715  	rb := append([]byte(nil), frag1...)
   716  	rb = append(rb, ctrl...)
   717  	rb = append(rb, frag2...)
   718  	bufs, err := c.wsRead(ri, tr, rb)
   719  	if err != nil {
   720  		t.Fatalf("Unexpected error: %v", err)
   721  	}
   722  	if n := len(bufs); n != 2 {
   723  		t.Fatalf("Unexpected buffer returned: %v", n)
   724  	}
   725  	if string(bufs[0]) != "first" {
   726  		t.Fatalf("Unexpected content: %s", bufs[0])
   727  	}
   728  	if string(bufs[1]) != "second" {
   729  		t.Fatalf("Unexpected content: %s", bufs[1])
   730  	}
   731  }
   732  
   733  func TestWSCloseFrameWithPartialOrInvalid(t *testing.T) {
   734  	c, ri, tr := testWSSetupForRead()
   735  	// a close message has a status in 2 bytes + optional payload
   736  	payloadTxt := []byte("hello")
   737  	payload := make([]byte, 2+len(payloadTxt))
   738  	binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure)
   739  	copy(payload[2:], payloadTxt)
   740  	closeMsg := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload)
   741  
   742  	// We will pass to wsRead a buffer of small capacity that contains
   743  	// only 1 byte.
   744  	closeFirtByte := []byte{closeMsg[0]}
   745  	// Make the io reader return the rest of the frame
   746  	tr.buf = closeMsg[1:]
   747  	bufs, err := c.wsRead(ri, tr, closeFirtByte[:])
   748  	// It is expected that wsRead returns io.EOF on processing a close.
   749  	if err != io.EOF {
   750  		t.Fatalf("Unexpected error: %v", err)
   751  	}
   752  	if n := len(bufs); n != 0 {
   753  		t.Fatalf("Unexpected buffer returned: %v", n)
   754  	}
   755  	// A CLOSE should have been queued with the payload of the original close message.
   756  	c.mu.Lock()
   757  	nb, _ := c.collapsePtoNB()
   758  	c.mu.Unlock()
   759  	if n := len(nb); n == 0 {
   760  		t.Fatalf("Expected buffers, got %v", n)
   761  	}
   762  	if expected := 2 + 2 + len(payloadTxt); expected != len(nb[0]) {
   763  		t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
   764  	}
   765  	b := nb[0][0]
   766  	if b&wsFinalBit == 0 {
   767  		t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
   768  	}
   769  	if b&byte(wsCloseMessage) == 0 {
   770  		t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
   771  	}
   772  	if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure {
   773  		t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status)
   774  	}
   775  	if !bytes.Equal(nb[0][4:], payloadTxt) {
   776  		t.Fatalf("Unexpected content: %s", nb[0][4:])
   777  	}
   778  
   779  	// Now test close with invalid status size (1 instead of 2 bytes)
   780  	c, ri, tr = testWSSetupForRead()
   781  	payload[0] = 100
   782  	binary.BigEndian.PutUint16(payload, wsCloseStatusNormalClosure)
   783  	closeMsg = testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload[:1])
   784  
   785  	// We will pass to wsRead a buffer of small capacity that contains
   786  	// only 1 byte.
   787  	closeFirtByte = []byte{closeMsg[0]}
   788  	// Make the io reader return the rest of the frame
   789  	tr.buf = closeMsg[1:]
   790  	bufs, err = c.wsRead(ri, tr, closeFirtByte[:])
   791  	// It is expected that wsRead returns io.EOF on processing a close.
   792  	if err != io.EOF {
   793  		t.Fatalf("Unexpected error: %v", err)
   794  	}
   795  	if n := len(bufs); n != 0 {
   796  		t.Fatalf("Unexpected buffer returned: %v", n)
   797  	}
   798  	// A CLOSE should have been queued with the payload of the original close message.
   799  	c.mu.Lock()
   800  	nb, _ = c.collapsePtoNB()
   801  	c.mu.Unlock()
   802  	if n := len(nb); n == 0 {
   803  		t.Fatalf("Expected buffers, got %v", n)
   804  	}
   805  	if expected := 2 + 2; expected != len(nb[0]) {
   806  		t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
   807  	}
   808  	b = nb[0][0]
   809  	if b&wsFinalBit == 0 {
   810  		t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
   811  	}
   812  	if b&byte(wsCloseMessage) == 0 {
   813  		t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
   814  	}
   815  	// Since satus was not valid, we should get wsCloseStatusNoStatusReceived
   816  	if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNoStatusReceived {
   817  		t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNoStatusReceived, status)
   818  	}
   819  	if len(nb[0][:]) != 4 {
   820  		t.Fatalf("Unexpected content: %s", nb[0][2:])
   821  	}
   822  }
   823  
   824  func TestWSReadGetErrors(t *testing.T) {
   825  	tr := &testReader{err: fmt.Errorf("on purpose")}
   826  	for _, test := range []struct {
   827  		lenPayload int
   828  		rbextra    int
   829  	}{
   830  		{10, 1},
   831  		{10, 3},
   832  		{200, 1},
   833  		{200, 2},
   834  		{200, 5},
   835  		{70000, 1},
   836  		{70000, 5},
   837  		{70000, 13},
   838  	} {
   839  		t.Run("", func(t *testing.T) {
   840  			c, ri, _ := testWSSetupForRead()
   841  			msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg"))
   842  			frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, make([]byte, test.lenPayload))
   843  			rb := append([]byte(nil), msg...)
   844  			rb = append(rb, frame...)
   845  			bufs, err := c.wsRead(ri, tr, rb[:len(msg)+test.rbextra])
   846  			if err == nil || err.Error() != "on purpose" {
   847  				t.Fatalf("Expected 'on purpose' error, got %v", err)
   848  			}
   849  			if n := len(bufs); n != 1 {
   850  				t.Fatalf("Unexpected buffer returned: %v", n)
   851  			}
   852  			if string(bufs[0]) != "msg" {
   853  				t.Fatalf("Unexpected content: %s", bufs[0])
   854  			}
   855  		})
   856  	}
   857  }
   858  
   859  func TestWSHandleControlFrameErrors(t *testing.T) {
   860  	c, ri, tr := testWSSetupForRead()
   861  	tr.err = fmt.Errorf("on purpose")
   862  
   863  	// a close message has a status in 2 bytes + optional payload
   864  	text := []byte("this is a close message")
   865  	payload := make([]byte, 2+len(text))
   866  	binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure)
   867  	copy(payload[2:], text)
   868  	ctrl := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload)
   869  
   870  	bufs, err := c.wsRead(ri, tr, ctrl[:len(ctrl)-4])
   871  	if err == nil || err.Error() != "on purpose" {
   872  		t.Fatalf("Expected 'on purpose' error, got %v", err)
   873  	}
   874  	if n := len(bufs); n != 0 {
   875  		t.Fatalf("Unexpected buffer returned: %v", n)
   876  	}
   877  
   878  	// Alter the content of close message. It is supposed to be valid utf-8.
   879  	c, ri, tr = testWSSetupForRead()
   880  	cp := append([]byte(nil), payload...)
   881  	cp[10] = 0xF1
   882  	ctrl = testWSCreateClientMsg(wsCloseMessage, 1, true, false, cp)
   883  	bufs, err = c.wsRead(ri, tr, ctrl)
   884  	// We should still receive an EOF but the message enqueued to the client
   885  	// should contain wsCloseStatusInvalidPayloadData and the error about invalid utf8
   886  	if err != io.EOF {
   887  		t.Fatalf("Unexpected error: %v", err)
   888  	}
   889  	if n := len(bufs); n != 0 {
   890  		t.Fatalf("Unexpected buffer returned: %v", n)
   891  	}
   892  	c.mu.Lock()
   893  	nb, _ := c.collapsePtoNB()
   894  	c.mu.Unlock()
   895  	if n := len(nb); n == 0 {
   896  		t.Fatalf("Expected buffers, got %v", n)
   897  	}
   898  	b := nb[0][0]
   899  	if b&wsFinalBit == 0 {
   900  		t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
   901  	}
   902  	if b&byte(wsCloseMessage) == 0 {
   903  		t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
   904  	}
   905  	if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusInvalidPayloadData {
   906  		t.Fatalf("Expected status to be %v, got %v", wsCloseStatusInvalidPayloadData, status)
   907  	}
   908  	if !bytes.Contains(nb[0][4:], []byte("utf8")) {
   909  		t.Fatalf("Unexpected content: %s", nb[0][4:])
   910  	}
   911  }
   912  
   913  func TestWSReadErrors(t *testing.T) {
   914  	for _, test := range []struct {
   915  		cframe func() []byte
   916  		err    string
   917  		nbufs  int
   918  	}{
   919  		{
   920  			func() []byte {
   921  				msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello"))
   922  				msg[1] &= ^byte(wsMaskBit)
   923  				return msg
   924  			},
   925  			"mask bit missing", 1,
   926  		},
   927  		{
   928  			func() []byte {
   929  				return testWSCreateClientMsg(wsPingMessage, 1, true, false, make([]byte, 200))
   930  			},
   931  			"control frame length bigger than maximum allowed", 1,
   932  		},
   933  		{
   934  			func() []byte {
   935  				return testWSCreateClientMsg(wsPingMessage, 1, false, false, []byte("hello"))
   936  			},
   937  			"control frame does not have final bit set", 1,
   938  		},
   939  		{
   940  			func() []byte {
   941  				frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("frag1"))
   942  				newMsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("new message"))
   943  				all := append([]byte(nil), frag1...)
   944  				all = append(all, newMsg...)
   945  				return all
   946  			},
   947  			"new message started before final frame for previous message was received", 2,
   948  		},
   949  		{
   950  			func() []byte {
   951  				frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("frame"))
   952  				frag := testWSCreateClientMsg(wsBinaryMessage, 2, false, false, []byte("continuation"))
   953  				all := append([]byte(nil), frame...)
   954  				all = append(all, frag...)
   955  				return all
   956  			},
   957  			"invalid continuation frame", 2,
   958  		},
   959  		{
   960  			func() []byte {
   961  				return testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frame"))
   962  			},
   963  			"invalid continuation frame", 1,
   964  		},
   965  		{
   966  			func() []byte {
   967  				return testWSCreateClientMsg(99, 1, false, false, []byte("hello"))
   968  			},
   969  			"unknown opcode", 1,
   970  		},
   971  	} {
   972  		t.Run(test.err, func(t *testing.T) {
   973  			c, ri, tr := testWSSetupForRead()
   974  			// Add a valid message first
   975  			msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello"))
   976  			// Then add the bad frame
   977  			bad := test.cframe()
   978  			// Add them both to a read buffer
   979  			rb := append([]byte(nil), msg...)
   980  			rb = append(rb, bad...)
   981  			bufs, err := c.wsRead(ri, tr, rb)
   982  			if err == nil || !strings.Contains(err.Error(), test.err) {
   983  				t.Fatalf("Expected error to contain %q, got %q", test.err, err.Error())
   984  			}
   985  			if n := len(bufs); n != test.nbufs {
   986  				t.Fatalf("Unexpected number of buffers: %v", n)
   987  			}
   988  			if string(bufs[0]) != "hello" {
   989  				t.Fatalf("Unexpected content: %s", bufs[0])
   990  			}
   991  		})
   992  	}
   993  }
   994  
   995  func TestWSEnqueueCloseMsg(t *testing.T) {
   996  	for _, test := range []struct {
   997  		reason ClosedState
   998  		status int
   999  	}{
  1000  		{ClientClosed, wsCloseStatusNormalClosure},
  1001  		{AuthenticationTimeout, wsCloseStatusPolicyViolation},
  1002  		{AuthenticationViolation, wsCloseStatusPolicyViolation},
  1003  		{SlowConsumerPendingBytes, wsCloseStatusPolicyViolation},
  1004  		{SlowConsumerWriteDeadline, wsCloseStatusPolicyViolation},
  1005  		{MaxAccountConnectionsExceeded, wsCloseStatusPolicyViolation},
  1006  		{MaxConnectionsExceeded, wsCloseStatusPolicyViolation},
  1007  		{MaxControlLineExceeded, wsCloseStatusPolicyViolation},
  1008  		{MaxSubscriptionsExceeded, wsCloseStatusPolicyViolation},
  1009  		{MissingAccount, wsCloseStatusPolicyViolation},
  1010  		{AuthenticationExpired, wsCloseStatusPolicyViolation},
  1011  		{Revocation, wsCloseStatusPolicyViolation},
  1012  		{TLSHandshakeError, wsCloseStatusTLSHandshake},
  1013  		{ParseError, wsCloseStatusProtocolError},
  1014  		{ProtocolViolation, wsCloseStatusProtocolError},
  1015  		{BadClientProtocolVersion, wsCloseStatusProtocolError},
  1016  		{MaxPayloadExceeded, wsCloseStatusMessageTooBig},
  1017  		{ServerShutdown, wsCloseStatusGoingAway},
  1018  		{WriteError, wsCloseStatusAbnormalClosure},
  1019  		{ReadError, wsCloseStatusAbnormalClosure},
  1020  		{StaleConnection, wsCloseStatusAbnormalClosure},
  1021  		{ClosedState(254), wsCloseStatusInternalSrvError},
  1022  	} {
  1023  		t.Run(test.reason.String(), func(t *testing.T) {
  1024  			c, _, _ := testWSSetupForRead()
  1025  			c.wsEnqueueCloseMessage(test.reason)
  1026  			c.mu.Lock()
  1027  			nb, _ := c.collapsePtoNB()
  1028  			c.mu.Unlock()
  1029  			if n := len(nb); n != 1 {
  1030  				t.Fatalf("Expected 1 buffer, got %v", n)
  1031  			}
  1032  			b := nb[0][0]
  1033  			if b&wsFinalBit == 0 {
  1034  				t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
  1035  			}
  1036  			if b&byte(wsCloseMessage) == 0 {
  1037  				t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
  1038  			}
  1039  			if status := binary.BigEndian.Uint16(nb[0][2:4]); int(status) != test.status {
  1040  				t.Fatalf("Expected status to be %v, got %v", test.status, status)
  1041  			}
  1042  			if string(nb[0][4:]) != test.reason.String() {
  1043  				t.Fatalf("Unexpected content: %s", nb[0][4:])
  1044  			}
  1045  		})
  1046  	}
  1047  }
  1048  
  1049  type testResponseWriter struct {
  1050  	http.ResponseWriter
  1051  	buf     bytes.Buffer
  1052  	headers http.Header
  1053  	err     error
  1054  	brw     *bufio.ReadWriter
  1055  	conn    *testWSFakeNetConn
  1056  }
  1057  
  1058  func (trw *testResponseWriter) Write(p []byte) (int, error) {
  1059  	return trw.buf.Write(p)
  1060  }
  1061  
  1062  func (trw *testResponseWriter) WriteHeader(status int) {
  1063  	trw.buf.WriteString(fmt.Sprintf("%v", status))
  1064  }
  1065  
  1066  func (trw *testResponseWriter) Header() http.Header {
  1067  	if trw.headers == nil {
  1068  		trw.headers = make(http.Header)
  1069  	}
  1070  	return trw.headers
  1071  }
  1072  
  1073  type testWSFakeNetConn struct {
  1074  	net.Conn
  1075  	wbuf            bytes.Buffer
  1076  	err             error
  1077  	wsOpened        bool
  1078  	isClosed        bool
  1079  	deadlineCleared bool
  1080  }
  1081  
  1082  func (c *testWSFakeNetConn) Write(p []byte) (int, error) {
  1083  	if c.err != nil {
  1084  		return 0, c.err
  1085  	}
  1086  	return c.wbuf.Write(p)
  1087  }
  1088  
  1089  func (c *testWSFakeNetConn) SetDeadline(t time.Time) error {
  1090  	if t.IsZero() {
  1091  		c.deadlineCleared = true
  1092  	}
  1093  	return nil
  1094  }
  1095  
  1096  func (c *testWSFakeNetConn) Close() error {
  1097  	c.isClosed = true
  1098  	return nil
  1099  }
  1100  
  1101  func (trw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  1102  	if trw.conn == nil {
  1103  		trw.conn = &testWSFakeNetConn{}
  1104  	}
  1105  	trw.conn.wsOpened = true
  1106  	if trw.brw == nil {
  1107  		trw.brw = bufio.NewReadWriter(bufio.NewReader(trw.conn), bufio.NewWriter(trw.conn))
  1108  	}
  1109  	return trw.conn, trw.brw, trw.err
  1110  }
  1111  
  1112  func testWSOptions() *Options {
  1113  	opts := DefaultOptions()
  1114  	opts.DisableShortFirstPing = true
  1115  	opts.Websocket.Host = "127.0.0.1"
  1116  	opts.Websocket.Port = -1
  1117  	opts.NoSystemAccount = true
  1118  	var err error
  1119  	tc := &TLSConfigOpts{
  1120  		CertFile: "./configs/certs/server.pem",
  1121  		KeyFile:  "./configs/certs/key.pem",
  1122  	}
  1123  	opts.Websocket.TLSConfig, err = GenTLSConfig(tc)
  1124  	if err != nil {
  1125  		panic(err)
  1126  	}
  1127  	return opts
  1128  }
  1129  
  1130  func testWSCreateValidReq() *http.Request {
  1131  	req := &http.Request{
  1132  		Method: "GET",
  1133  		Host:   "localhost",
  1134  		Proto:  "HTTP/1.1",
  1135  	}
  1136  	req.Header = make(http.Header)
  1137  	req.Header.Set("Upgrade", "websocket")
  1138  	req.Header.Set("Connection", "Upgrade")
  1139  	req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
  1140  	req.Header.Set("Sec-Websocket-Version", "13")
  1141  	return req
  1142  }
  1143  
  1144  func TestWSCheckOrigin(t *testing.T) {
  1145  	notSameOrigin := false
  1146  	sameOrigin := true
  1147  	allowedListEmpty := []string{}
  1148  	someList := []string{"http://host1.com", "http://host2.com:1234"}
  1149  
  1150  	for _, test := range []struct {
  1151  		name       string
  1152  		sameOrigin bool
  1153  		origins    []string
  1154  		reqHost    string
  1155  		reqTLS     bool
  1156  		origin     string
  1157  		err        string
  1158  	}{
  1159  		{"any", notSameOrigin, allowedListEmpty, "", false, "http://any.host.com", ""},
  1160  		{"same origin ok", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:80", ""},
  1161  		{"same origin bad host", sameOrigin, allowedListEmpty, "host.com", false, "http://other.host.com", "not same origin"},
  1162  		{"same origin bad port", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:81", "not same origin"},
  1163  		{"same origin bad scheme", sameOrigin, allowedListEmpty, "host.com", true, "http://host.com", "not same origin"},
  1164  		{"same origin bad uri", sameOrigin, allowedListEmpty, "host.com", false, "@@@://invalid:url:1234", "invalid URI"},
  1165  		{"same origin bad url", sameOrigin, allowedListEmpty, "host.com", false, "http://invalid:url:1234", "too many colons"},
  1166  		{"same origin bad req host", sameOrigin, allowedListEmpty, "invalid:url:1234", false, "http://host.com", "too many colons"},
  1167  		{"no origin same origin ignored", sameOrigin, allowedListEmpty, "", false, "", ""},
  1168  		{"no origin list ignored", sameOrigin, someList, "", false, "", ""},
  1169  		{"no origin same origin and list ignored", sameOrigin, someList, "", false, "", ""},
  1170  		{"allowed from list", notSameOrigin, someList, "", false, "http://host2.com:1234", ""},
  1171  		{"allowed with different path", notSameOrigin, someList, "", false, "http://host1.com/some/path", ""},
  1172  		{"list bad port", notSameOrigin, someList, "", false, "http://host1.com:1234", "not in the allowed list"},
  1173  		{"list bad scheme", notSameOrigin, someList, "", false, "https://host2.com:1234", "not in the allowed list"},
  1174  	} {
  1175  		t.Run(test.name, func(t *testing.T) {
  1176  			opts := DefaultOptions()
  1177  			opts.Websocket.SameOrigin = test.sameOrigin
  1178  			opts.Websocket.AllowedOrigins = test.origins
  1179  			s := &Server{opts: opts}
  1180  			s.wsSetOriginOptions(&opts.Websocket)
  1181  
  1182  			req := testWSCreateValidReq()
  1183  			req.Host = test.reqHost
  1184  			if test.reqTLS {
  1185  				req.TLS = &tls.ConnectionState{}
  1186  			}
  1187  			if test.origin != "" {
  1188  				req.Header.Set("Origin", test.origin)
  1189  			}
  1190  			err := s.websocket.checkOrigin(req)
  1191  			if test.err == "" && err != nil {
  1192  				t.Fatalf("Unexpected error: %v", err)
  1193  			} else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) {
  1194  				t.Fatalf("Expected error %q, got %v", test.err, err)
  1195  			}
  1196  		})
  1197  	}
  1198  }
  1199  
// TestWSUpgradeValidationErrors runs wsUpgrade on inputs that are each broken
// in one way. Each case's setup returns the server options, an optional
// response writer (nil means a fresh testResponseWriter is used) and the
// request. For every case the test checks:
//   - the returned error contains the expected text,
//   - no result is returned,
//   - the HTTP status and body written to the client match, and
//   - any connection obtained through Hijack has been closed again.
func TestWSUpgradeValidationErrors(t *testing.T) {
	for _, test := range []struct {
		name   string
		setup  func() (*Options, *testResponseWriter, *http.Request)
		err    string
		status int
	}{
		{
			"bad method",
			func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				req := testWSCreateValidReq()
				req.Method = "POST"
				return opts, nil, req
			},
			"must be GET",
			http.StatusMethodNotAllowed,
		},
		{
			"no host",
			func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				req := testWSCreateValidReq()
				req.Host = ""
				return opts, nil, req
			},
			"'Host' missing in request",
			http.StatusBadRequest,
		},
		{
			"invalid upgrade header",
			func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				req := testWSCreateValidReq()
				req.Header.Del("Upgrade")
				return opts, nil, req
			},
			"invalid value for header 'Upgrade'",
			http.StatusBadRequest,
		},
		{
			"invalid connection header",
			func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				req := testWSCreateValidReq()
				req.Header.Del("Connection")
				return opts, nil, req
			},
			"invalid value for header 'Connection'",
			http.StatusBadRequest,
		},
		{
			"no key",
			func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				req := testWSCreateValidReq()
				req.Header.Del("Sec-Websocket-Key")
				return opts, nil, req
			},
			"key missing",
			http.StatusBadRequest,
		},
		{
			"empty key",
			func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				req := testWSCreateValidReq()
				req.Header.Set("Sec-Websocket-Key", "")
				return opts, nil, req
			},
			"key missing",
			http.StatusBadRequest,
		},
		{
			"missing version",
			func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				req := testWSCreateValidReq()
				req.Header.Del("Sec-Websocket-Version")
				return opts, nil, req
			},
			"invalid version",
			http.StatusBadRequest,
		},
		{
			"wrong version",
			func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				req := testWSCreateValidReq()
				req.Header.Set("Sec-Websocket-Version", "99")
				return opts, nil, req
			},
			"invalid version",
			http.StatusBadRequest,
		},
		{
			"origin",
			func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				opts.Websocket.SameOrigin = true
				req := testWSCreateValidReq()
				req.Header.Set("Origin", "http://bad.host.com")
				return opts, nil, req
			},
			"origin not allowed",
			http.StatusForbidden,
		},
		{
			"hijack error",
			func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				// Hijack will return this error (together with a fake conn).
				rw := &testResponseWriter{err: fmt.Errorf("on purpose")}
				req := testWSCreateValidReq()
				return opts, rw, req
			},
			"on purpose",
			http.StatusInternalServerError,
		},
		{
			"hijack buffered data",
			func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				buf := &bytes.Buffer{}
				buf.WriteString("some data")
				rw := &testResponseWriter{
					conn: &testWSFakeNetConn{},
					brw:  bufio.NewReadWriter(bufio.NewReader(buf), bufio.NewWriter(nil)),
				}
				// Read one byte through the bufio.Reader so that it fills its
				// buffer from buf and keeps the remaining bytes buffered. The
				// hijacked reader then looks as if the client sent data before
				// the handshake completed. The result is deliberately ignored:
				// reading from an in-memory buffer of known content only serves
				// to prime the bufio.Reader.
				tmp := [1]byte{}
				io.ReadAtLeast(rw.brw, tmp[:1], 1)
				req := testWSCreateValidReq()
				return opts, rw, req
			},
			"client sent data before handshake is complete",
			http.StatusBadRequest,
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			opts, rw, req := test.setup()
			if rw == nil {
				rw = &testResponseWriter{}
			}
			s := &Server{opts: opts}
			s.wsSetOriginOptions(&opts.Websocket)
			res, err := s.wsUpgrade(rw, req)
			if err == nil || !strings.Contains(err.Error(), test.err) {
				t.Fatalf("Should get error %q, got %v", test.err, err)
			}
			if res != nil {
				t.Fatalf("Should not have returned a result, got %v", res)
			}
			// rw.buf holds the status digits written by WriteHeader followed
			// by the body, expected here to be the standard status text plus
			// a newline.
			expected := fmt.Sprintf("%v%s\n", test.status, http.StatusText(test.status))
			if got := rw.buf.String(); got != expected {
				t.Fatalf("Expected %q got %q", expected, got)
			}
			// Check that if the connection was opened, it is now closed.
			if rw.conn != nil && rw.conn.wsOpened && !rw.conn.isClosed {
				t.Fatal("Connection was opened, but has not been closed")
			}
		})
	}
}
  1362  
  1363  func TestWSUpgradeResponseWriteError(t *testing.T) {
  1364  	opts := testWSOptions()
  1365  	s := &Server{opts: opts}
  1366  	expectedErr := errors.New("on purpose")
  1367  	rw := &testResponseWriter{
  1368  		conn: &testWSFakeNetConn{err: expectedErr},
  1369  	}
  1370  	req := testWSCreateValidReq()
  1371  	res, err := s.wsUpgrade(rw, req)
  1372  	if err != expectedErr {
  1373  		t.Fatalf("Should get error %q, got %v", expectedErr.Error(), err)
  1374  	}
  1375  	if res != nil {
  1376  		t.Fatalf("Should not have returned a result, got %v", res)
  1377  	}
  1378  	if !rw.conn.isClosed {
  1379  		t.Fatal("Connection should have been closed")
  1380  	}
  1381  }
  1382  
  1383  func TestWSUpgradeConnDeadline(t *testing.T) {
  1384  	opts := testWSOptions()
  1385  	opts.Websocket.HandshakeTimeout = time.Second
  1386  	s := &Server{opts: opts}
  1387  	rw := &testResponseWriter{}
  1388  	req := testWSCreateValidReq()
  1389  	res, err := s.wsUpgrade(rw, req)
  1390  	if res == nil || err != nil {
  1391  		t.Fatalf("Unexpected error: %v", err)
  1392  	}
  1393  	if rw.conn.isClosed {
  1394  		t.Fatal("Connection should NOT have been closed")
  1395  	}
  1396  	if !rw.conn.deadlineCleared {
  1397  		t.Fatal("Connection deadline should have been cleared after handshake")
  1398  	}
  1399  }
  1400  
  1401  func TestWSCompressNegotiation(t *testing.T) {
  1402  	// No compression on the server, but client asks
  1403  	opts := testWSOptions()
  1404  	s := &Server{opts: opts}
  1405  	rw := &testResponseWriter{}
  1406  	req := testWSCreateValidReq()
  1407  	req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
  1408  	res, err := s.wsUpgrade(rw, req)
  1409  	if res == nil || err != nil {
  1410  		t.Fatalf("Unexpected error: %v", err)
  1411  	}
  1412  	// The http response should not contain "permessage-deflate"
  1413  	output := rw.conn.wbuf.String()
  1414  	if strings.Contains(output, "permessage-deflate") {
  1415  		t.Fatalf("Compression disabled in server so response to client should not contain extension, got %s", output)
  1416  	}
  1417  
  1418  	// Option in the server and client, so compression should be negotiated.
  1419  	s.opts.Websocket.Compression = true
  1420  	rw = &testResponseWriter{}
  1421  	res, err = s.wsUpgrade(rw, req)
  1422  	if res == nil || err != nil {
  1423  		t.Fatalf("Unexpected error: %v", err)
  1424  	}
  1425  	// The http response should not contain "permessage-deflate"
  1426  	output = rw.conn.wbuf.String()
  1427  	if !strings.Contains(output, "permessage-deflate") {
  1428  		t.Fatalf("Compression in server and client request, so response should contain extension, got %s", output)
  1429  	}
  1430  
  1431  	// Option in server but not asked by the client, so response should not contain "permessage-deflate"
  1432  	rw = &testResponseWriter{}
  1433  	req.Header.Del("Sec-Websocket-Extensions")
  1434  	res, err = s.wsUpgrade(rw, req)
  1435  	if res == nil || err != nil {
  1436  		t.Fatalf("Unexpected error: %v", err)
  1437  	}
  1438  	// The http response should not contain "permessage-deflate"
  1439  	output = rw.conn.wbuf.String()
  1440  	if strings.Contains(output, "permessage-deflate") {
  1441  		t.Fatalf("Compression in server but not in client, so response to client should not contain extension, got %s", output)
  1442  	}
  1443  }
  1444  
  1445  func TestWSSetHeader(t *testing.T) {
  1446  	opts := testWSOptions()
  1447  	opts.Websocket.Headers = map[string]string{
  1448  		"X-Header":         "some-value",
  1449  		"X-Another-Header": "another-value",
  1450  	}
  1451  	s := &Server{opts: opts}
  1452  	s.wsSetHeadersOptions(&opts.Websocket)
  1453  	rw := &testResponseWriter{}
  1454  	req := testWSCreateValidReq()
  1455  	res, err := s.wsUpgrade(rw, req)
  1456  	if res == nil || err != nil {
  1457  		t.Fatalf("Unexpected error: %v", err)
  1458  	}
  1459  
  1460  	buf := bufio.NewReader(&rw.conn.wbuf)
  1461  	resp, err := http.ReadResponse(buf, req)
  1462  	if err != nil {
  1463  		t.Fatalf("Error reading request: %v", err)
  1464  	}
  1465  	defer resp.Body.Close()
  1466  
  1467  	// Check that the response is a 101
  1468  	if resp.StatusCode != http.StatusSwitchingProtocols {
  1469  		t.Fatalf("Expected 101, got %v", resp.StatusCode)
  1470  	}
  1471  
  1472  	headers := resp.Header.Clone()
  1473  
  1474  	// Compare all the headers
  1475  	for k, v := range opts.Websocket.Headers {
  1476  		if got := headers.Get(k); got != v {
  1477  			t.Fatalf("Expected %q for header %q, got %q", v, k, got)
  1478  		}
  1479  		headers.Del(k)
  1480  	}
  1481  
  1482  	// Check remain headers
  1483  	for k, v := range map[string]string{
  1484  		"Upgrade":              "websocket",
  1485  		"Connection":           "Upgrade",
  1486  		"Sec-Websocket-Accept": wsAcceptKey(req.Header.Get("Sec-Websocket-Key")),
  1487  	} {
  1488  		if got := headers.Get(k); got != v {
  1489  			t.Fatalf("Expected %q for header %q, got %q", v, k, got)
  1490  		}
  1491  		headers.Del(k)
  1492  	}
  1493  
  1494  	// Check that we have no more headers
  1495  	if len(headers) > 0 {
  1496  		t.Fatalf("Unexpected headers: %v", headers)
  1497  	}
  1498  }
  1499  
  1500  func TestWSParseOptions(t *testing.T) {
  1501  	for _, test := range []struct {
  1502  		name     string
  1503  		content  string
  1504  		checkOpt func(*WebsocketOpts) error
  1505  		err      string
  1506  	}{
  1507  		// Negative tests
  1508  		{"bad type", "websocket: []", nil, "to be a map"},
  1509  		{"bad listen", "websocket: { listen: [] }", nil, "port or host:port"},
  1510  		{"bad port", `websocket: { port: "abc" }`, nil, "not int64"},
  1511  		{"bad host", `websocket: { host: 123 }`, nil, "not string"},
  1512  		{"bad advertise type", `websocket: { advertise: 123 }`, nil, "not string"},
  1513  		{"bad tls", `websocket: { tls: 123 }`, nil, "not map[string]interface {}"},
  1514  		{"bad same origin", `websocket: { same_origin: "abc" }`, nil, "not bool"},
  1515  		{"bad allowed origins type", `websocket: { allowed_origins: {} }`, nil, "unsupported type"},
  1516  		{"bad allowed origins values", `websocket: { allowed_origins: [ {} ] }`, nil, "unsupported type in array"},
  1517  		{"bad handshake timeout type", `websocket: { handshake_timeout: [] }`, nil, "unsupported type"},
  1518  		{"bad handshake timeout duration", `websocket: { handshake_timeout: "abc" }`, nil, "invalid duration"},
  1519  		{"bad header type", `websocket: { headers: 123 }`, nil, "unsupported type"},
  1520  		{"bad header type", `websocket: { headers: [] }`, nil, "unsupported type"},
  1521  		{"bad header value", `websocket: { headers: { "key": 123 } }`, nil, "unsupported type"},
  1522  		{"unknown field", `websocket: { this_does_not_exist: 123 }`, nil, "unknown"},
  1523  		// Positive tests
  1524  		{"listen port only", `websocket { listen: 1234 }`, func(wo *WebsocketOpts) error {
  1525  			if wo.Port != 1234 {
  1526  				return fmt.Errorf("expected 1234, got %v", wo.Port)
  1527  			}
  1528  			return nil
  1529  		}, ""},
  1530  		{"listen host and port", `websocket { listen: "localhost:1234" }`, func(wo *WebsocketOpts) error {
  1531  			if wo.Host != "localhost" || wo.Port != 1234 {
  1532  				return fmt.Errorf("expected localhost:1234, got %v:%v", wo.Host, wo.Port)
  1533  			}
  1534  			return nil
  1535  		}, ""},
  1536  		{"host", `websocket { host: "localhost" }`, func(wo *WebsocketOpts) error {
  1537  			if wo.Host != "localhost" {
  1538  				return fmt.Errorf("expected localhost, got %v", wo.Host)
  1539  			}
  1540  			return nil
  1541  		}, ""},
  1542  		{"port", `websocket { port: 1234 }`, func(wo *WebsocketOpts) error {
  1543  			if wo.Port != 1234 {
  1544  				return fmt.Errorf("expected 1234, got %v", wo.Port)
  1545  			}
  1546  			return nil
  1547  		}, ""},
  1548  		{"advertise", `websocket { advertise: "host:1234" }`, func(wo *WebsocketOpts) error {
  1549  			if wo.Advertise != "host:1234" {
  1550  				return fmt.Errorf("expected %q, got %q", "host:1234", wo.Advertise)
  1551  			}
  1552  			return nil
  1553  		}, ""},
  1554  		{"same origin", `websocket { same_origin: true }`, func(wo *WebsocketOpts) error {
  1555  			if !wo.SameOrigin {
  1556  				return fmt.Errorf("expected same_origin==true, got %v", wo.SameOrigin)
  1557  			}
  1558  			return nil
  1559  		}, ""},
  1560  		{"allowed origins one only", `websocket { allowed_origins: "https://host.com/" }`, func(wo *WebsocketOpts) error {
  1561  			expected := []string{"https://host.com/"}
  1562  			if !reflect.DeepEqual(wo.AllowedOrigins, expected) {
  1563  				return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins)
  1564  			}
  1565  			return nil
  1566  		}, ""},
  1567  		{"allowed origins array",
  1568  			`
  1569  			websocket {
  1570  				allowed_origins: [
  1571  					"https://host1.com/"
  1572  					"https://host2.com/"
  1573  				]
  1574  			}
  1575  			`, func(wo *WebsocketOpts) error {
  1576  				expected := []string{"https://host1.com/", "https://host2.com/"}
  1577  				if !reflect.DeepEqual(wo.AllowedOrigins, expected) {
  1578  					return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins)
  1579  				}
  1580  				return nil
  1581  			}, ""},
  1582  		{"handshake timeout in whole seconds", `websocket { handshake_timeout: 3 }`, func(wo *WebsocketOpts) error {
  1583  			if wo.HandshakeTimeout != 3*time.Second {
  1584  				return fmt.Errorf("expected handshake to be 3s, got %v", wo.HandshakeTimeout)
  1585  			}
  1586  			return nil
  1587  		}, ""},
  1588  		{"handshake timeout n duration", `websocket { handshake_timeout: "4s" }`, func(wo *WebsocketOpts) error {
  1589  			if wo.HandshakeTimeout != 4*time.Second {
  1590  				return fmt.Errorf("expected handshake to be 4s, got %v", wo.HandshakeTimeout)
  1591  			}
  1592  			return nil
  1593  		}, ""},
  1594  		{"tls config",
  1595  			`
  1596  			websocket {
  1597  				tls {
  1598  					cert_file: "./configs/certs/server.pem"
  1599  					key_file: "./configs/certs/key.pem"
  1600  				}
  1601  			}
  1602  			`, func(wo *WebsocketOpts) error {
  1603  				if wo.TLSConfig == nil {
  1604  					return fmt.Errorf("TLSConfig should have been set")
  1605  				}
  1606  				return nil
  1607  			}, ""},
  1608  		{"compression",
  1609  			`
  1610  			websocket {
  1611  				compression: true
  1612  			}
  1613  			`, func(wo *WebsocketOpts) error {
  1614  				if !wo.Compression {
  1615  					return fmt.Errorf("Compression should have been set")
  1616  				}
  1617  				return nil
  1618  			}, ""},
  1619  		{"jwt cookie",
  1620  			`
  1621  			websocket {
  1622  				jwt_cookie: "jwtcookie"
  1623  			}
  1624  			`, func(wo *WebsocketOpts) error {
  1625  				if wo.JWTCookie != "jwtcookie" {
  1626  					return fmt.Errorf("Invalid JWTCookie value: %q", wo.JWTCookie)
  1627  				}
  1628  				return nil
  1629  			}, ""},
  1630  		{"no auth user",
  1631  			`
  1632  			websocket {
  1633  				no_auth_user: "noauthuser"
  1634  			}
  1635  			`, func(wo *WebsocketOpts) error {
  1636  				if wo.NoAuthUser != "noauthuser" {
  1637  					return fmt.Errorf("Invalid NoAuthUser value: %q", wo.NoAuthUser)
  1638  				}
  1639  				return nil
  1640  			}, ""},
  1641  		{"auth block",
  1642  			`
  1643  			websocket {
  1644  				authorization {
  1645  					user: "webuser"
  1646  					password: "pwd"
  1647  					token: "token"
  1648  					timeout: 2.0
  1649  				}
  1650  			}
  1651  			`, func(wo *WebsocketOpts) error {
  1652  				if wo.Username != "webuser" || wo.Password != "pwd" || wo.Token != "token" || wo.AuthTimeout != 2.0 {
  1653  					return fmt.Errorf("Invalid auth block: %+v", wo)
  1654  				}
  1655  				return nil
  1656  			}, ""},
  1657  		{"auth timeout as int",
  1658  			`
  1659  			websocket {
  1660  				authorization {
  1661  					timeout: 2
  1662  				}
  1663  			}
  1664  			`, func(wo *WebsocketOpts) error {
  1665  				if wo.AuthTimeout != 2.0 {
  1666  					return fmt.Errorf("Invalid auth timeout: %v", wo.AuthTimeout)
  1667  				}
  1668  				return nil
  1669  			}, ""},
  1670  		{"headers block",
  1671  			`
  1672  			websocket {
  1673  				headers {
  1674  					"X-Header": "some-value"
  1675  					"X-Another-Header": "another-value"
  1676  				}
  1677  			}
  1678  			`, func(wo *WebsocketOpts) error {
  1679  				if len(wo.Headers) != 2 {
  1680  					return fmt.Errorf("Expected 2 headers, got %v", len(wo.Headers))
  1681  				}
  1682  
  1683  				for k, v := range map[string]string{
  1684  					"X-Header":         "some-value",
  1685  					"X-Another-Header": "another-value",
  1686  				} {
  1687  					if got, ok := wo.Headers[k]; !ok || got != v {
  1688  						return fmt.Errorf("Invalid value for %q: %q", k, got)
  1689  					}
  1690  				}
  1691  				return nil
  1692  			}, ""},
  1693  	} {
  1694  		t.Run(test.name, func(t *testing.T) {
  1695  			conf := createConfFile(t, []byte(test.content))
  1696  			o, err := ProcessConfigFile(conf)
  1697  			if test.err != _EMPTY_ {
  1698  				if err == nil || !strings.Contains(err.Error(), test.err) {
  1699  					t.Fatalf("For content: %q, expected error about %q, got %v", test.content, test.err, err)
  1700  				}
  1701  				return
  1702  			} else if err != nil {
  1703  				t.Fatalf("Unexpected error for content %q: %v", test.content, err)
  1704  			}
  1705  			if err := test.checkOpt(&o.Websocket); err != nil {
  1706  				t.Fatalf("Incorrect option for content %q: %v", test.content, err.Error())
  1707  			}
  1708  		})
  1709  	}
  1710  }
  1711  
  1712  func TestWSValidateOptions(t *testing.T) {
  1713  	nwso := DefaultOptions()
  1714  	wso := testWSOptions()
  1715  	for _, test := range []struct {
  1716  		name    string
  1717  		getOpts func() *Options
  1718  		err     string
  1719  	}{
  1720  		{"websocket disabled", func() *Options { return nwso.Clone() }, ""},
  1721  		{"no tls", func() *Options { o := wso.Clone(); o.Websocket.TLSConfig = nil; return o }, "requires TLS configuration"},
  1722  		{"bad url in allowed list", func() *Options {
  1723  			o := wso.Clone()
  1724  			o.Websocket.AllowedOrigins = []string{"http://this:is:bad:url"}
  1725  			return o
  1726  		}, "unable to parse"},
  1727  		{"missing trusted configuration", func() *Options {
  1728  			o := wso.Clone()
  1729  			o.Websocket.JWTCookie = "jwt"
  1730  			return o
  1731  		}, "keys configuration is required"},
  1732  		{"websocket username not allowed if users specified", func() *Options {
  1733  			o := wso.Clone()
  1734  			o.Nkeys = []*NkeyUser{{Nkey: "abc"}}
  1735  			o.Websocket.Username = "b"
  1736  			o.Websocket.Password = "pwd"
  1737  			return o
  1738  		}, "websocket authentication username not compatible with presence of users/nkeys"},
  1739  		{"websocket token not allowed if users specified", func() *Options {
  1740  			o := wso.Clone()
  1741  			o.Nkeys = []*NkeyUser{{Nkey: "abc"}}
  1742  			o.Websocket.Token = "mytoken"
  1743  			return o
  1744  		}, "websocket authentication token not compatible with presence of users/nkeys"},
  1745  		{"headers with sec-websocket- prefix not allowed", func() *Options {
  1746  			o := wso.Clone()
  1747  			o.Websocket.Headers = map[string]string{"Sec-WebSocket-Key": "123"}
  1748  			return o
  1749  		}, `invalid header "Sec-WebSocket-Key", "Sec-WebSocket-" prefix not allowed`},
  1750  		{"header with host", func() *Options {
  1751  			o := wso.Clone()
  1752  			o.Websocket.Headers = map[string]string{"Host": "http://localhost:8080"}
  1753  			return o
  1754  		}, `websocket: invalid header "Host" not allowed`},
  1755  		{"header with content-length", func() *Options {
  1756  			o := wso.Clone()
  1757  			o.Websocket.Headers = map[string]string{"Content-Length": "0"}
  1758  			return o
  1759  		}, `websocket: invalid header "Content-Length" not allowed`},
  1760  		{"header with connection", func() *Options {
  1761  			o := wso.Clone()
  1762  			o.Websocket.Headers = map[string]string{"Connection": "Upgrade"}
  1763  			return o
  1764  		}, `websocket: invalid header "Connection" not allowed`},
  1765  		{"header with upgrade", func() *Options {
  1766  			o := wso.Clone()
  1767  			o.Websocket.Headers = map[string]string{"Upgrade": "websocket"}
  1768  			return o
  1769  		}, `websocket: invalid header "Upgrade" not allowed`},
  1770  		{"header with Nats-No-Masking", func() *Options {
  1771  			o := wso.Clone()
  1772  			o.Websocket.Headers = map[string]string{"Nats-No-Masking": "false"}
  1773  			return o
  1774  		}, `websocket: invalid header "Nats-No-Masking" not allowed`},
  1775  	} {
  1776  		t.Run(test.name, func(t *testing.T) {
  1777  			err := validateWebsocketOptions(test.getOpts())
  1778  			if test.err == "" && err != nil {
  1779  				t.Fatalf("Unexpected error: %v", err)
  1780  			} else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) {
  1781  				t.Fatalf("Expected error to contain %q, got %v", test.err, err)
  1782  			}
  1783  		})
  1784  	}
  1785  }
  1786  
  1787  func TestWSSetOriginOptions(t *testing.T) {
  1788  	o := testWSOptions()
  1789  	for _, test := range []struct {
  1790  		content string
  1791  		err     string
  1792  	}{
  1793  		{"@@@://host.com/", "invalid URI"},
  1794  		{"http://this:is:bad:url/", "invalid port"},
  1795  	} {
  1796  		t.Run(test.err, func(t *testing.T) {
  1797  			o.Websocket.AllowedOrigins = []string{test.content}
  1798  			s := &Server{}
  1799  			l := &captureErrorLogger{errCh: make(chan string, 1)}
  1800  			s.SetLogger(l, false, false)
  1801  			s.wsSetOriginOptions(&o.Websocket)
  1802  			select {
  1803  			case e := <-l.errCh:
  1804  				if !strings.Contains(e, test.err) {
  1805  					t.Fatalf("Unexpected error: %v", e)
  1806  				}
  1807  			case <-time.After(50 * time.Millisecond):
  1808  				t.Fatalf("Did not get the error")
  1809  			}
  1810  
  1811  		})
  1812  	}
  1813  }
  1814  
// captureFatalLogger is a test logger that embeds DummyLogger and records the
// formatted text of Fatalf calls on fatalCh.
type captureFatalLogger struct {
	DummyLogger
	// fatalCh receives Fatalf messages. Its Fatalf method sends without
	// blocking, so messages are dropped when nothing can accept them.
	fatalCh chan string
}
  1819  
  1820  func (l *captureFatalLogger) Fatalf(format string, v ...any) {
  1821  	select {
  1822  	case l.fatalCh <- fmt.Sprintf(format, v...):
  1823  	default:
  1824  	}
  1825  }
  1826  
  1827  func TestWSFailureToStartServer(t *testing.T) {
  1828  	// Create a listener to use a port
  1829  	l, err := net.Listen("tcp", "127.0.0.1:0")
  1830  	if err != nil {
  1831  		t.Fatalf("Error listening: %v", err)
  1832  	}
  1833  	defer l.Close()
  1834  
  1835  	o := testWSOptions()
  1836  	// Make sure we don't have unnecessary listen ports opened.
  1837  	o.HTTPPort = 0
  1838  	o.Cluster.Port = 0
  1839  	o.Gateway.Name = ""
  1840  	o.Gateway.Port = 0
  1841  	o.LeafNode.Port = 0
  1842  	o.Websocket.Port = l.Addr().(*net.TCPAddr).Port
  1843  	s, err := NewServer(o)
  1844  	if err != nil {
  1845  		t.Fatalf("Error creating server: %v", err)
  1846  	}
  1847  	defer s.Shutdown()
  1848  	logger := &captureFatalLogger{fatalCh: make(chan string, 1)}
  1849  	s.SetLogger(logger, false, false)
  1850  
  1851  	wg := sync.WaitGroup{}
  1852  	wg.Add(1)
  1853  	go func() {
  1854  		s.Start()
  1855  		wg.Done()
  1856  	}()
  1857  
  1858  	select {
  1859  	case e := <-logger.fatalCh:
  1860  		if !strings.Contains(e, "Unable to listen") {
  1861  			t.Fatalf("Unexpected error: %v", e)
  1862  		}
  1863  	case <-time.After(2 * time.Second):
  1864  		t.Fatalf("Should have reported a fatal error")
  1865  	}
  1866  	// Since this is a test and the process does not actually
  1867  	// exit on Fatal error, wait for the client port to be
  1868  	// ready so when we shutdown we don't leave the accept
  1869  	// loop hanging.
  1870  	checkFor(t, time.Second, 15*time.Millisecond, func() error {
  1871  		s.mu.Lock()
  1872  		ready := s.listener != nil
  1873  		s.mu.Unlock()
  1874  		if !ready {
  1875  			return fmt.Errorf("client accept loop not started yet")
  1876  		}
  1877  		return nil
  1878  	})
  1879  	s.Shutdown()
  1880  	wg.Wait()
  1881  }
  1882  
  1883  func TestWSAbnormalFailureOfWebServer(t *testing.T) {
  1884  	o := testWSOptions()
  1885  	s := RunServer(o)
  1886  	defer s.Shutdown()
  1887  	logger := &captureFatalLogger{fatalCh: make(chan string, 1)}
  1888  	s.SetLogger(logger, false, false)
  1889  
  1890  	// Now close the WS listener to cause a WebServer error
  1891  	s.mu.Lock()
  1892  	s.websocket.listener.Close()
  1893  	s.mu.Unlock()
  1894  
  1895  	select {
  1896  	case e := <-logger.fatalCh:
  1897  		if !strings.Contains(e, "websocket listener error") {
  1898  			t.Fatalf("Unexpected error: %v", e)
  1899  		}
  1900  	case <-time.After(2 * time.Second):
  1901  		t.Fatalf("Should have reported a fatal error")
  1902  	}
  1903  }
  1904  
// testWSClientOptions configures the websocket client created by
// testNewWSClient / testNewWSClientWithError.
type testWSClientOptions struct {
	// compress requests permessage-deflate; web sets a browser User-Agent.
	compress, web        bool
	// host and port form the address that is dialed.
	host                 string
	port                 int
	// extraHeaders are added to the upgrade request; a key with no values
	// is sent with an empty value.
	extraHeaders         map[string][]string
	// noTLS skips the client-side TLS handshake.
	noTLS                bool
	// path is appended to the request URL; when it equals mqttWSPath the
	// response must carry the MQTT subprotocol and no INFO frame is read.
	path                 string
	// extraResponseHeaders maps header names to the exact values the
	// upgrade response must return for them.
	extraResponseHeaders map[string]string
}
  1914  
  1915  func testNewWSClient(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte) {
  1916  	t.Helper()
  1917  	c, br, info, err := testNewWSClientWithError(t, o)
  1918  	if err != nil {
  1919  		t.Fatal(err)
  1920  	}
  1921  	return c, br, info
  1922  }
  1923  
  1924  func testNewWSClientWithError(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte, error) {
  1925  	addr := fmt.Sprintf("%s:%d", o.host, o.port)
  1926  	wsc, err := net.Dial("tcp", addr)
  1927  	if err != nil {
  1928  		return nil, nil, nil, fmt.Errorf("Error creating ws connection: %v", err)
  1929  	}
  1930  	if !o.noTLS {
  1931  		wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
  1932  		wsc.SetDeadline(time.Now().Add(time.Second))
  1933  		if err := wsc.(*tls.Conn).Handshake(); err != nil {
  1934  			return nil, nil, nil, fmt.Errorf("Error during handshake: %v", err)
  1935  		}
  1936  		wsc.SetDeadline(time.Time{})
  1937  	}
  1938  	req := testWSCreateValidReq()
  1939  	if o.compress {
  1940  		req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
  1941  	}
  1942  	if o.web {
  1943  		req.Header.Set("User-Agent", "Mozilla/5.0")
  1944  	}
  1945  	if len(o.extraHeaders) > 0 {
  1946  		for hdr, values := range o.extraHeaders {
  1947  			if len(values) == 0 {
  1948  				req.Header.Set(hdr, _EMPTY_)
  1949  				continue
  1950  			}
  1951  			req.Header.Set(hdr, values[0])
  1952  			for i := 1; i < len(values); i++ {
  1953  				req.Header.Add(hdr, values[i])
  1954  			}
  1955  		}
  1956  	}
  1957  	req.URL, _ = url.Parse("wss://" + addr + o.path)
  1958  	if err := req.Write(wsc); err != nil {
  1959  		return nil, nil, nil, fmt.Errorf("Error sending request: %v", err)
  1960  	}
  1961  	br := bufio.NewReader(wsc)
  1962  	resp, err := http.ReadResponse(br, req)
  1963  	if err != nil {
  1964  		return nil, nil, nil, fmt.Errorf("Error reading response: %v", err)
  1965  	}
  1966  	defer resp.Body.Close()
  1967  	if resp.StatusCode != http.StatusSwitchingProtocols {
  1968  		return nil, nil, nil, fmt.Errorf("Expected response status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
  1969  	}
  1970  	for k, v := range o.extraResponseHeaders {
  1971  		if value := resp.Header.Get(k); value != v {
  1972  			return nil, nil, nil, fmt.Errorf("Expected header %q to be %q, got %q", k, v, value)
  1973  		}
  1974  	}
  1975  	var info []byte
  1976  	if o.path == mqttWSPath {
  1977  		if v := resp.Header[wsSecProto]; len(v) != 1 || v[0] != wsMQTTSecProtoVal {
  1978  			return nil, nil, nil, fmt.Errorf("No mqtt protocol in header: %v", resp.Header)
  1979  		}
  1980  	} else {
  1981  		// Wait for the INFO
  1982  		info = testWSReadFrame(t, br)
  1983  		if !bytes.HasPrefix(info, []byte("INFO {")) {
  1984  			return nil, nil, nil, fmt.Errorf("Expected INFO, got %s", info)
  1985  		}
  1986  	}
  1987  	return wsc, br, info, nil
  1988  }
  1989  
// testClaimsOptions customizes how testWSWithClaims builds credentials and
// authenticates its websocket client.
type testClaimsOptions struct {
	// nac is the account claims template (nil means fresh claims). Its
	// Subject is overwritten with the generated account public key.
	nac            *jwt.AccountClaims
	// nuc is the user claims template (nil means fresh claims). Its Subject
	// is overwritten with the generated user public key.
	nuc            *jwt.UserClaims
	// connectRequest, when non-nil, is JSON-encoded and sent as the CONNECT
	// payload instead of the generated user JWT.
	connectRequest any
	// dontSign sends the user JWT without a signature over the nonce.
	dontSign       bool
	// expectAnswer is the required prefix of the first frame received after
	// sending CONNECT and PING.
	expectAnswer   string
}
  1997  
  1998  func testWSWithClaims(t *testing.T, s *Server, o testWSClientOptions, tclm testClaimsOptions) (kp nkeys.KeyPair, conn net.Conn, rdr *bufio.Reader, auth_was_required bool) {
  1999  	t.Helper()
  2000  
  2001  	okp, _ := nkeys.FromSeed(oSeed)
  2002  
  2003  	akp, _ := nkeys.CreateAccount()
  2004  	apub, _ := akp.PublicKey()
  2005  	if tclm.nac == nil {
  2006  		tclm.nac = jwt.NewAccountClaims(apub)
  2007  	} else {
  2008  		tclm.nac.Subject = apub
  2009  	}
  2010  	ajwt, err := tclm.nac.Encode(okp)
  2011  	if err != nil {
  2012  		t.Fatalf("Error generating account JWT: %v", err)
  2013  	}
  2014  
  2015  	nkp, _ := nkeys.CreateUser()
  2016  	pub, _ := nkp.PublicKey()
  2017  	if tclm.nuc == nil {
  2018  		tclm.nuc = jwt.NewUserClaims(pub)
  2019  	} else {
  2020  		tclm.nuc.Subject = pub
  2021  	}
  2022  	jwt, err := tclm.nuc.Encode(akp)
  2023  	if err != nil {
  2024  		t.Fatalf("Error generating user JWT: %v", err)
  2025  	}
  2026  
  2027  	addAccountToMemResolver(s, apub, ajwt)
  2028  
  2029  	c, cr, l := testNewWSClient(t, o)
  2030  
  2031  	var info struct {
  2032  		Nonce        string `json:"nonce,omitempty"`
  2033  		AuthRequired bool   `json:"auth_required,omitempty"`
  2034  	}
  2035  
  2036  	if err := json.Unmarshal([]byte(l[5:]), &info); err != nil {
  2037  		t.Fatal(err)
  2038  	}
  2039  	if info.AuthRequired {
  2040  		cs := ""
  2041  		if tclm.connectRequest != nil {
  2042  			customReq, err := json.Marshal(tclm.connectRequest)
  2043  			if err != nil {
  2044  				t.Fatal(err)
  2045  			}
  2046  			// PING needed to flush the +OK/-ERR to us.
  2047  			cs = fmt.Sprintf("CONNECT %v\r\nPING\r\n", string(customReq))
  2048  		} else if !tclm.dontSign {
  2049  			// Sign Nonce
  2050  			sigraw, _ := nkp.Sign([]byte(info.Nonce))
  2051  			sig := base64.RawURLEncoding.EncodeToString(sigraw)
  2052  			cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"sig\":\"%s\",\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt, sig)
  2053  		} else {
  2054  			cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt)
  2055  		}
  2056  		wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(cs))
  2057  		c.Write(wsmsg)
  2058  		l = testWSReadFrame(t, cr)
  2059  		if !strings.HasPrefix(string(l), tclm.expectAnswer) {
  2060  			t.Fatalf("Expected %q, got %q", tclm.expectAnswer, l)
  2061  		}
  2062  	}
  2063  	return akp, c, cr, info.AuthRequired
  2064  }
  2065  
  2066  func setupAddTrusted(o *Options) {
  2067  	kp, _ := nkeys.FromSeed(oSeed)
  2068  	pub, _ := kp.PublicKey()
  2069  	o.TrustedKeys = []string{pub}
  2070  }
  2071  
  2072  func setupAddCookie(o *Options) {
  2073  	o.Websocket.JWTCookie = "jwt"
  2074  }
  2075  
  2076  func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int, cookies ...string) (net.Conn, *bufio.Reader, []byte) {
  2077  	t.Helper()
  2078  	opts := testWSClientOptions{
  2079  		compress: compress,
  2080  		web:      web,
  2081  		host:     host,
  2082  		port:     port,
  2083  	}
  2084  
  2085  	if len(cookies) > 0 {
  2086  		opts.extraHeaders = map[string][]string{}
  2087  		opts.extraHeaders["Cookie"] = cookies
  2088  	}
  2089  	return testNewWSClient(t, opts)
  2090  }
  2091  
  2092  func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) {
  2093  	wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port)
  2094  	// Send CONNECT and PING
  2095  	wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
  2096  	if _, err := wsc.Write(wsmsg); err != nil {
  2097  		t.Fatalf("Error sending message: %v", err)
  2098  	}
  2099  	// Wait for the PONG
  2100  	if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  2101  		t.Fatalf("Expected PONG, got %s", msg)
  2102  	}
  2103  	return wsc, br
  2104  }
  2105  
  2106  func testWSReadFrame(t testing.TB, br *bufio.Reader) []byte {
  2107  	t.Helper()
  2108  	fh := [2]byte{}
  2109  	if _, err := io.ReadAtLeast(br, fh[:2], 2); err != nil {
  2110  		t.Fatalf("Error reading frame: %v", err)
  2111  	}
  2112  	fc := fh[0]&wsRsv1Bit != 0
  2113  	sb := fh[1]
  2114  	size := 0
  2115  	switch {
  2116  	case sb <= 125:
  2117  		size = int(sb)
  2118  	case sb == 126:
  2119  		tmp := [2]byte{}
  2120  		if _, err := io.ReadAtLeast(br, tmp[:2], 2); err != nil {
  2121  			t.Fatalf("Error reading frame: %v", err)
  2122  		}
  2123  		size = int(binary.BigEndian.Uint16(tmp[:2]))
  2124  	case sb == 127:
  2125  		tmp := [8]byte{}
  2126  		if _, err := io.ReadAtLeast(br, tmp[:8], 8); err != nil {
  2127  			t.Fatalf("Error reading frame: %v", err)
  2128  		}
  2129  		size = int(binary.BigEndian.Uint64(tmp[:8]))
  2130  	}
  2131  	buf := make([]byte, size)
  2132  	if _, err := io.ReadAtLeast(br, buf, size); err != nil {
  2133  		t.Fatalf("Error reading frame: %v", err)
  2134  	}
  2135  	if !fc {
  2136  		return buf
  2137  	}
  2138  	buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
  2139  	dbr := bytes.NewBuffer(buf)
  2140  	d := flate.NewReader(dbr)
  2141  	uncompressed, err := io.ReadAll(d)
  2142  	if err != nil {
  2143  		t.Fatalf("Error reading frame: %v", err)
  2144  	}
  2145  	return uncompressed
  2146  }
  2147  
  2148  func TestWSPubSub(t *testing.T) {
  2149  	for _, test := range []struct {
  2150  		name        string
  2151  		compression bool
  2152  	}{
  2153  		{"no compression", false},
  2154  		{"compression", true},
  2155  	} {
  2156  		t.Run(test.name, func(t *testing.T) {
  2157  			o := testWSOptions()
  2158  			if test.compression {
  2159  				o.Websocket.Compression = true
  2160  			}
  2161  			s := RunServer(o)
  2162  			defer s.Shutdown()
  2163  
  2164  			// Create a regular client to subscribe
  2165  			nc := natsConnect(t, s.ClientURL())
  2166  			defer nc.Close()
  2167  			nsub := natsSubSync(t, nc, "foo")
  2168  			checkExpectedSubs(t, 1, s)
  2169  
  2170  			// Now create a WS client and send a message on "foo"
  2171  			wsc, br := testWSCreateClient(t, test.compression, false, o.Websocket.Host, o.Websocket.Port)
  2172  			defer wsc.Close()
  2173  
  2174  			// Send a WS message for "PUB foo 2\r\nok\r\n"
  2175  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n"))
  2176  			if _, err := wsc.Write(wsmsg); err != nil {
  2177  				t.Fatalf("Error sending message: %v", err)
  2178  			}
  2179  
  2180  			// Now check that message is received
  2181  			msg := natsNexMsg(t, nsub, time.Second)
  2182  			if string(msg.Data) != "from ws" {
  2183  				t.Fatalf("Expected message to be %q, got %q", "ok", string(msg.Data))
  2184  			}
  2185  
  2186  			// Now do reverse, create a subscription on WS client on bar
  2187  			wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB bar 1\r\n"))
  2188  			if _, err := wsc.Write(wsmsg); err != nil {
  2189  				t.Fatalf("Error sending subscription: %v", err)
  2190  			}
  2191  			// Wait for it to be registered on server
  2192  			checkExpectedSubs(t, 2, s)
  2193  			// Now publish from NATS connection and verify received on WS client
  2194  			natsPub(t, nc, "bar", []byte("from nats"))
  2195  			natsFlush(t, nc)
  2196  
  2197  			// Check for the "from nats" message...
  2198  			// Set some deadline so we are not stuck forever on failure
  2199  			wsc.SetReadDeadline(time.Now().Add(10 * time.Second))
  2200  			ok := 0
  2201  			for {
  2202  				line, _, err := br.ReadLine()
  2203  				if err != nil {
  2204  					t.Fatalf("Error reading: %v", err)
  2205  				}
  2206  				// Note that this works even in compression test because those
  2207  				// texts are likely not to be compressed, but compression code is
  2208  				// still executed.
  2209  				if ok == 0 && bytes.Contains(line, []byte("MSG bar 1 9")) {
  2210  					ok = 1
  2211  					continue
  2212  				} else if ok == 1 && bytes.Contains(line, []byte("from nats")) {
  2213  					break
  2214  				}
  2215  			}
  2216  		})
  2217  	}
  2218  }
  2219  
  2220  func TestWSTLSConnection(t *testing.T) {
  2221  	o := testWSOptions()
  2222  	s := RunServer(o)
  2223  	defer s.Shutdown()
  2224  
  2225  	addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
  2226  
  2227  	for _, test := range []struct {
  2228  		name   string
  2229  		useTLS bool
  2230  		status int
  2231  	}{
  2232  		{"client uses TLS", true, http.StatusSwitchingProtocols},
  2233  		{"client does not use TLS", false, http.StatusBadRequest},
  2234  	} {
  2235  		t.Run(test.name, func(t *testing.T) {
  2236  			wsc, err := net.Dial("tcp", addr)
  2237  			if err != nil {
  2238  				t.Fatalf("Error creating ws connection: %v", err)
  2239  			}
  2240  			defer wsc.Close()
  2241  			if test.useTLS {
  2242  				wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
  2243  				if err := wsc.(*tls.Conn).Handshake(); err != nil {
  2244  					t.Fatalf("Error during handshake: %v", err)
  2245  				}
  2246  			}
  2247  			req := testWSCreateValidReq()
  2248  			var scheme string
  2249  			if test.useTLS {
  2250  				scheme = "s"
  2251  			}
  2252  			req.URL, _ = url.Parse("ws" + scheme + "://" + addr)
  2253  			if err := req.Write(wsc); err != nil {
  2254  				t.Fatalf("Error sending request: %v", err)
  2255  			}
  2256  			br := bufio.NewReader(wsc)
  2257  			resp, err := http.ReadResponse(br, req)
  2258  			if err != nil {
  2259  				t.Fatalf("Error reading response: %v", err)
  2260  			}
  2261  			defer resp.Body.Close()
  2262  			if resp.StatusCode != test.status {
  2263  				t.Fatalf("Expected status %v, got %v", test.status, resp.StatusCode)
  2264  			}
  2265  		})
  2266  	}
  2267  }
  2268  
  2269  func TestWSTLSVerifyClientCert(t *testing.T) {
  2270  	o := testWSOptions()
  2271  	tc := &TLSConfigOpts{
  2272  		CertFile: "../test/configs/certs/server-cert.pem",
  2273  		KeyFile:  "../test/configs/certs/server-key.pem",
  2274  		CaFile:   "../test/configs/certs/ca.pem",
  2275  		Verify:   true,
  2276  	}
  2277  	tlsc, err := GenTLSConfig(tc)
  2278  	if err != nil {
  2279  		t.Fatalf("Error creating tls config: %v", err)
  2280  	}
  2281  	o.Websocket.TLSConfig = tlsc
  2282  	s := RunServer(o)
  2283  	defer s.Shutdown()
  2284  
  2285  	addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
  2286  
  2287  	for _, test := range []struct {
  2288  		name        string
  2289  		provideCert bool
  2290  	}{
  2291  		{"client provides cert", true},
  2292  		{"client does not provide cert", false},
  2293  	} {
  2294  		t.Run(test.name, func(t *testing.T) {
  2295  			wsc, err := net.Dial("tcp", addr)
  2296  			if err != nil {
  2297  				t.Fatalf("Error creating ws connection: %v", err)
  2298  			}
  2299  			defer wsc.Close()
  2300  			tlsc := &tls.Config{}
  2301  			if test.provideCert {
  2302  				tc := &TLSConfigOpts{
  2303  					CertFile: "../test/configs/certs/client-cert.pem",
  2304  					KeyFile:  "../test/configs/certs/client-key.pem",
  2305  				}
  2306  				var err error
  2307  				tlsc, err = GenTLSConfig(tc)
  2308  				if err != nil {
  2309  					t.Fatalf("Error generating tls config: %v", err)
  2310  				}
  2311  			}
  2312  			tlsc.InsecureSkipVerify = true
  2313  			wsc = tls.Client(wsc, tlsc)
  2314  			if err := wsc.(*tls.Conn).Handshake(); err != nil {
  2315  				t.Fatalf("Error during handshake: %v", err)
  2316  			}
  2317  			req := testWSCreateValidReq()
  2318  			req.URL, _ = url.Parse("wss://" + addr)
  2319  			if err := req.Write(wsc); err != nil {
  2320  				t.Fatalf("Error sending request: %v", err)
  2321  			}
  2322  			br := bufio.NewReader(wsc)
  2323  			resp, err := http.ReadResponse(br, req)
  2324  			if resp != nil {
  2325  				resp.Body.Close()
  2326  			}
  2327  			if !test.provideCert {
  2328  				if err == nil {
  2329  					t.Fatal("Expected error, did not get one")
  2330  				} else if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "certificate required") {
  2331  					t.Fatalf("Unexpected error: %v", err)
  2332  				}
  2333  				return
  2334  			}
  2335  			if err != nil {
  2336  				t.Fatalf("Unexpected error: %v", err)
  2337  			}
  2338  			if resp.StatusCode != http.StatusSwitchingProtocols {
  2339  				t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
  2340  			}
  2341  		})
  2342  	}
  2343  }
  2344  
  2345  func testCreateAllowedConnectionTypes(list []string) map[string]struct{} {
  2346  	if len(list) == 0 {
  2347  		return nil
  2348  	}
  2349  	m := make(map[string]struct{}, len(list))
  2350  	for _, l := range list {
  2351  		m[l] = struct{}{}
  2352  	}
  2353  	return m
  2354  }
  2355  
  2356  func TestWSTLSVerifyAndMap(t *testing.T) {
  2357  	accName := "MyAccount"
  2358  	acc := NewAccount(accName)
  2359  	certUserName := "CN=example.com,OU=NATS.io"
  2360  	users := []*User{{Username: certUserName, Account: acc}}
  2361  
  2362  	for _, test := range []struct {
  2363  		name        string
  2364  		filtering   bool
  2365  		provideCert bool
  2366  	}{
  2367  		{"no filtering, client provides cert", false, true},
  2368  		{"no filtering, client does not provide cert", false, false},
  2369  		{"filtering, client provides cert", true, true},
  2370  		{"filtering, client does not provide cert", true, false},
  2371  		{"no users override, client provides cert", false, true},
  2372  		{"no users override, client does not provide cert", false, false},
  2373  		{"users override, client provides cert", true, true},
  2374  		{"users override, client does not provide cert", true, false},
  2375  	} {
  2376  		t.Run(test.name, func(t *testing.T) {
  2377  			o := testWSOptions()
  2378  			o.Accounts = []*Account{acc}
  2379  			o.Users = users
  2380  			if test.filtering {
  2381  				o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket})
  2382  			}
  2383  			tc := &TLSConfigOpts{
  2384  				CertFile: "../test/configs/certs/tlsauth/server.pem",
  2385  				KeyFile:  "../test/configs/certs/tlsauth/server-key.pem",
  2386  				CaFile:   "../test/configs/certs/tlsauth/ca.pem",
  2387  				Verify:   true,
  2388  			}
  2389  			tlsc, err := GenTLSConfig(tc)
  2390  			if err != nil {
  2391  				t.Fatalf("Error creating tls config: %v", err)
  2392  			}
  2393  			o.Websocket.TLSConfig = tlsc
  2394  			o.Websocket.TLSMap = true
  2395  			s := RunServer(o)
  2396  			defer s.Shutdown()
  2397  
  2398  			addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
  2399  			wsc, err := net.Dial("tcp", addr)
  2400  			if err != nil {
  2401  				t.Fatalf("Error creating ws connection: %v", err)
  2402  			}
  2403  			defer wsc.Close()
  2404  			tlscc := &tls.Config{}
  2405  			if test.provideCert {
  2406  				tc := &TLSConfigOpts{
  2407  					CertFile: "../test/configs/certs/tlsauth/client.pem",
  2408  					KeyFile:  "../test/configs/certs/tlsauth/client-key.pem",
  2409  				}
  2410  				var err error
  2411  				tlscc, err = GenTLSConfig(tc)
  2412  				if err != nil {
  2413  					t.Fatalf("Error generating tls config: %v", err)
  2414  				}
  2415  			}
  2416  			tlscc.InsecureSkipVerify = true
  2417  			wsc = tls.Client(wsc, tlscc)
  2418  			if err := wsc.(*tls.Conn).Handshake(); err != nil {
  2419  				t.Fatalf("Error during handshake: %v", err)
  2420  			}
  2421  			req := testWSCreateValidReq()
  2422  			req.URL, _ = url.Parse("wss://" + addr)
  2423  			if err := req.Write(wsc); err != nil {
  2424  				t.Fatalf("Error sending request: %v", err)
  2425  			}
  2426  			br := bufio.NewReader(wsc)
  2427  			resp, err := http.ReadResponse(br, req)
  2428  			if resp != nil {
  2429  				resp.Body.Close()
  2430  			}
  2431  			if !test.provideCert {
  2432  				if err == nil {
  2433  					t.Fatal("Expected error, did not get one")
  2434  				} else if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "certificate required") {
  2435  					t.Fatalf("Unexpected error: %v", err)
  2436  				}
  2437  				return
  2438  			}
  2439  			if err != nil {
  2440  				t.Fatalf("Unexpected error: %v", err)
  2441  			}
  2442  			if resp.StatusCode != http.StatusSwitchingProtocols {
  2443  				t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
  2444  			}
  2445  			// Wait for the INFO
  2446  			l := testWSReadFrame(t, br)
  2447  			if !bytes.HasPrefix(l, []byte("INFO {")) {
  2448  				t.Fatalf("Expected INFO, got %s", l)
  2449  			}
  2450  			var info serverInfo
  2451  			if err := json.Unmarshal(l[5:], &info); err != nil {
  2452  				t.Fatalf("Unable to unmarshal info: %v", err)
  2453  			}
  2454  			// Send CONNECT and PING
  2455  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
  2456  			if _, err := wsc.Write(wsmsg); err != nil {
  2457  				t.Fatalf("Error sending message: %v", err)
  2458  			}
  2459  			// Wait for the PONG
  2460  			if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  2461  				t.Fatalf("Expected PONG, got %s", msg)
  2462  			}
  2463  
  2464  			var uname string
  2465  			var accname string
  2466  			c := s.getClient(info.CID)
  2467  			if c != nil {
  2468  				c.mu.Lock()
  2469  				uname = c.opts.Username
  2470  				if c.acc != nil {
  2471  					accname = c.acc.GetName()
  2472  				}
  2473  				c.mu.Unlock()
  2474  			}
  2475  			if uname != certUserName {
  2476  				t.Fatalf("Expected username %q, got %q", certUserName, uname)
  2477  			}
  2478  			if accname != accName {
  2479  				t.Fatalf("Expected account %q, got %v", accName, accname)
  2480  			}
  2481  		})
  2482  	}
  2483  }
  2484  
  2485  func TestWSHandshakeTimeout(t *testing.T) {
  2486  	o := testWSOptions()
  2487  	o.Websocket.HandshakeTimeout = time.Millisecond
  2488  	tc := &TLSConfigOpts{
  2489  		CertFile: "./configs/certs/server.pem",
  2490  		KeyFile:  "./configs/certs/key.pem",
  2491  	}
  2492  	o.Websocket.TLSConfig, _ = GenTLSConfig(tc)
  2493  	s := RunServer(o)
  2494  	defer s.Shutdown()
  2495  
  2496  	logger := &captureErrorLogger{errCh: make(chan string, 1)}
  2497  	s.SetLogger(logger, false, false)
  2498  
  2499  	addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
  2500  	wsc, err := net.Dial("tcp", addr)
  2501  	if err != nil {
  2502  		t.Fatalf("Error creating ws connection: %v", err)
  2503  	}
  2504  	defer wsc.Close()
  2505  
  2506  	// Delay the handshake
  2507  	wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
  2508  	time.Sleep(20 * time.Millisecond)
  2509  	// We expect error since the server should have cut us off
  2510  	if err := wsc.(*tls.Conn).Handshake(); err == nil {
  2511  		t.Fatal("Expected error during handshake")
  2512  	}
  2513  
  2514  	// Check that server logs error
  2515  	select {
  2516  	case e := <-logger.errCh:
  2517  		// Check that log starts with "websocket: "
  2518  		if !strings.HasPrefix(e, "websocket: ") {
  2519  			t.Fatalf("Wrong log line start: %s", e)
  2520  		}
  2521  		if !strings.Contains(e, "timeout") {
  2522  			t.Fatalf("Unexpected error: %v", e)
  2523  		}
  2524  	case <-time.After(time.Second):
  2525  		t.Fatalf("Should have timed-out")
  2526  	}
  2527  }
  2528  
  2529  func TestWSServerReportUpgradeFailure(t *testing.T) {
  2530  	o := testWSOptions()
  2531  	s := RunServer(o)
  2532  	defer s.Shutdown()
  2533  
  2534  	logger := &captureErrorLogger{errCh: make(chan string, 1)}
  2535  	s.SetLogger(logger, false, false)
  2536  
  2537  	addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port)
  2538  	req := testWSCreateValidReq()
  2539  	req.URL, _ = url.Parse("wss://" + addr)
  2540  
  2541  	wsc, err := net.Dial("tcp", addr)
  2542  	if err != nil {
  2543  		t.Fatalf("Error creating ws connection: %v", err)
  2544  	}
  2545  	defer wsc.Close()
  2546  	wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
  2547  	if err := wsc.(*tls.Conn).Handshake(); err != nil {
  2548  		t.Fatalf("Error during handshake: %v", err)
  2549  	}
  2550  	// Remove a required field from the request to have it fail
  2551  	req.Header.Del("Connection")
  2552  	// Send the request
  2553  	if err := req.Write(wsc); err != nil {
  2554  		t.Fatalf("Error sending request: %v", err)
  2555  	}
  2556  	br := bufio.NewReader(wsc)
  2557  	resp, err := http.ReadResponse(br, req)
  2558  	if err != nil {
  2559  		t.Fatalf("Error reading response: %v", err)
  2560  	}
  2561  	defer resp.Body.Close()
  2562  	if resp.StatusCode != http.StatusBadRequest {
  2563  		t.Fatalf("Expected status %v, got %v", http.StatusBadRequest, resp.StatusCode)
  2564  	}
  2565  
  2566  	// Check that server logs error
  2567  	select {
  2568  	case e := <-logger.errCh:
  2569  		if !strings.Contains(e, "invalid value for header 'Connection'") {
  2570  			t.Fatalf("Unexpected error: %v", e)
  2571  		}
  2572  		// The client IP's local should be printed as a remote from server perspective.
  2573  		clientIP := wsc.LocalAddr().String()
  2574  		if !strings.HasPrefix(e, clientIP) {
  2575  			t.Fatalf("IP should have been logged, it was not: %v", e)
  2576  		}
  2577  	case <-time.After(time.Second):
  2578  		t.Fatalf("Should have timed-out")
  2579  	}
  2580  }
  2581  
  2582  func TestWSCloseMsgSendOnConnectionClose(t *testing.T) {
  2583  	o := testWSOptions()
  2584  	s := RunServer(o)
  2585  	defer s.Shutdown()
  2586  
  2587  	wsc, br := testWSCreateClient(t, false, false, o.Websocket.Host, o.Websocket.Port)
  2588  	defer wsc.Close()
  2589  
  2590  	checkClientsCount(t, s, 1)
  2591  	var c *client
  2592  	s.mu.Lock()
  2593  	for _, cli := range s.clients {
  2594  		c = cli
  2595  		break
  2596  	}
  2597  	s.mu.Unlock()
  2598  
  2599  	c.closeConnection(ProtocolViolation)
  2600  	msg := testWSReadFrame(t, br)
  2601  	if len(msg) < 2 {
  2602  		t.Fatalf("Should have 2 bytes to represent the status, got %v", msg)
  2603  	}
  2604  	if sc := int(binary.BigEndian.Uint16(msg[:2])); sc != wsCloseStatusProtocolError {
  2605  		t.Fatalf("Expected status to be %v, got %v", wsCloseStatusProtocolError, sc)
  2606  	}
  2607  	expectedPayload := ProtocolViolation.String()
  2608  	if p := string(msg[2:]); p != expectedPayload {
  2609  		t.Fatalf("Expected payload to be %q, got %q", expectedPayload, p)
  2610  	}
  2611  }
  2612  
  2613  func TestWSAdvertise(t *testing.T) {
  2614  	o := testWSOptions()
  2615  	o.Cluster.Port = 0
  2616  	o.HTTPPort = 0
  2617  	o.Websocket.Advertise = "xxx:host:yyy"
  2618  	s, err := NewServer(o)
  2619  	if err != nil {
  2620  		t.Fatalf("Unexpected error: %v", err)
  2621  	}
  2622  	defer s.Shutdown()
  2623  	l := &captureFatalLogger{fatalCh: make(chan string, 1)}
  2624  	s.SetLogger(l, false, false)
  2625  	s.Start()
  2626  	select {
  2627  	case e := <-l.fatalCh:
  2628  		if !strings.Contains(e, "Unable to get websocket connect URLs") {
  2629  			t.Fatalf("Unexpected error: %q", e)
  2630  		}
  2631  	case <-time.After(time.Second):
  2632  		t.Fatal("Should have failed to start")
  2633  	}
  2634  	s.Shutdown()
  2635  
  2636  	o1 := testWSOptions()
  2637  	o1.Websocket.Advertise = "host1:1234"
  2638  	s1 := RunServer(o1)
  2639  	defer s1.Shutdown()
  2640  
  2641  	wsc, br := testWSCreateClient(t, false, false, o1.Websocket.Host, o1.Websocket.Port)
  2642  	defer wsc.Close()
  2643  
  2644  	o2 := testWSOptions()
  2645  	o2.Websocket.Advertise = "host2:5678"
  2646  	o2.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", o1.Cluster.Host, o1.Cluster.Port))
  2647  	s2 := RunServer(o2)
  2648  	defer s2.Shutdown()
  2649  
  2650  	checkInfo := func(expected []string) {
  2651  		t.Helper()
  2652  		infob := testWSReadFrame(t, br)
  2653  		info := &Info{}
  2654  		json.Unmarshal(infob[5:], info)
  2655  		if n := len(info.ClientConnectURLs); n != len(expected) {
  2656  			t.Fatalf("Unexpected info: %+v", info)
  2657  		}
  2658  		good := 0
  2659  		for _, u := range info.ClientConnectURLs {
  2660  			for _, eu := range expected {
  2661  				if u == eu {
  2662  					good++
  2663  				}
  2664  			}
  2665  		}
  2666  		if good != len(expected) {
  2667  			t.Fatalf("Unexpected connect urls: %q", info.ClientConnectURLs)
  2668  		}
  2669  	}
  2670  	checkInfo([]string{"host1:1234", "host2:5678"})
  2671  
  2672  	// Now shutdown s2 and expect another INFO
  2673  	s2.Shutdown()
  2674  	checkInfo([]string{"host1:1234"})
  2675  
  2676  	// Restart with another advertise and check that it gets updated
  2677  	o2.Websocket.Advertise = "host3:9012"
  2678  	s2 = RunServer(o2)
  2679  	defer s2.Shutdown()
  2680  	checkInfo([]string{"host1:1234", "host3:9012"})
  2681  }
  2682  
  2683  func TestWSFrameOutbound(t *testing.T) {
  2684  	for _, test := range []struct {
  2685  		name         string
  2686  		maskingWrite bool
  2687  	}{
  2688  		{"no write masking", false},
  2689  		{"write masking", true},
  2690  	} {
  2691  		t.Run(test.name, func(t *testing.T) {
  2692  			c, _, _ := testWSSetupForRead()
  2693  			c.ws.maskwrite = test.maskingWrite
  2694  
  2695  			getKey := func(buf []byte) []byte {
  2696  				return buf[len(buf)-4:]
  2697  			}
  2698  
  2699  			var bufs net.Buffers
  2700  			bufs = append(bufs, []byte("this "))
  2701  			bufs = append(bufs, []byte("is "))
  2702  			bufs = append(bufs, []byte("a "))
  2703  			bufs = append(bufs, []byte("set "))
  2704  			bufs = append(bufs, []byte("of "))
  2705  			bufs = append(bufs, []byte("buffers"))
  2706  			en := 2
  2707  			for _, b := range bufs {
  2708  				en += len(b)
  2709  			}
  2710  			if test.maskingWrite {
  2711  				en += 4
  2712  			}
  2713  			c.mu.Lock()
  2714  			c.out.nb = bufs
  2715  			res, n := c.collapsePtoNB()
  2716  			c.mu.Unlock()
  2717  			if n != int64(en) {
  2718  				t.Fatalf("Expected size to be %v, got %v", en, n)
  2719  			}
  2720  			if eb := 1 + len(bufs); eb != len(res) {
  2721  				t.Fatalf("Expected %v buffers, got %v", eb, len(res))
  2722  			}
  2723  			var ob []byte
  2724  			for i := 1; i < len(res); i++ {
  2725  				ob = append(ob, res[i]...)
  2726  			}
  2727  			if test.maskingWrite {
  2728  				wsMaskBuf(getKey(res[0]), ob)
  2729  			}
  2730  			if !bytes.Equal(ob, []byte("this is a set of buffers")) {
  2731  				t.Fatalf("Unexpected outbound: %q", ob)
  2732  			}
  2733  
  2734  			bufs = nil
  2735  			c.out.pb = 0
  2736  			c.ws.fs = 0
  2737  			c.ws.frames = nil
  2738  			c.ws.browser = true
  2739  			bufs = append(bufs, []byte("some smaller "))
  2740  			bufs = append(bufs, []byte("buffers"))
  2741  			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+10))
  2742  			bufs = append(bufs, []byte("then some more"))
  2743  			en = 2 + len(bufs[0]) + len(bufs[1])
  2744  			en += 4 + len(bufs[2]) - 10
  2745  			en += 2 + len(bufs[3]) + 10
  2746  			c.mu.Lock()
  2747  			c.out.nb = bufs
  2748  			res, n = c.collapsePtoNB()
  2749  			c.mu.Unlock()
  2750  			if test.maskingWrite {
  2751  				en += 3 * 4
  2752  			}
  2753  			if n != int64(en) {
  2754  				t.Fatalf("Expected size to be %v, got %v", en, n)
  2755  			}
  2756  			if len(res) != 8 {
  2757  				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
  2758  			}
  2759  			if len(res[4]) != wsFrameSizeForBrowsers {
  2760  				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
  2761  			}
  2762  			if len(res[6]) != 10 {
  2763  				t.Fatalf("Frame 6 should have the partial of 10 bytes, got %v", len(res[6]))
  2764  			}
  2765  			if test.maskingWrite {
  2766  				b := &bytes.Buffer{}
  2767  				key := getKey(res[0])
  2768  				b.Write(res[1])
  2769  				b.Write(res[2])
  2770  				ud := b.Bytes()
  2771  				wsMaskBuf(key, ud)
  2772  				if string(ud) != "some smaller buffers" {
  2773  					t.Fatalf("Unexpected result: %q", ud)
  2774  				}
  2775  
  2776  				b.Reset()
  2777  				key = getKey(res[3])
  2778  				b.Write(res[4])
  2779  				ud = b.Bytes()
  2780  				wsMaskBuf(key, ud)
  2781  				for i := 0; i < len(ud); i++ {
  2782  					if ud[i] != 0 {
  2783  						t.Fatalf("Unexpected result: %v", ud)
  2784  					}
  2785  				}
  2786  
  2787  				b.Reset()
  2788  				key = getKey(res[5])
  2789  				b.Write(res[6])
  2790  				b.Write(res[7])
  2791  				ud = b.Bytes()
  2792  				wsMaskBuf(key, ud)
  2793  				for i := 0; i < len(ud[:10]); i++ {
  2794  					if ud[i] != 0 {
  2795  						t.Fatalf("Unexpected result: %v", ud[:10])
  2796  					}
  2797  				}
  2798  				if string(ud[10:]) != "then some more" {
  2799  					t.Fatalf("Unexpected result: %q", ud[10:])
  2800  				}
  2801  			}
  2802  
  2803  			bufs = nil
  2804  			c.out.pb = 0
  2805  			c.ws.fs = 0
  2806  			c.ws.frames = nil
  2807  			c.ws.browser = true
  2808  			bufs = append(bufs, []byte("some smaller "))
  2809  			bufs = append(bufs, []byte("buffers"))
  2810  			// Have one of the exact max size
  2811  			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers))
  2812  			bufs = append(bufs, []byte("then some more"))
  2813  			en = 2 + len(bufs[0]) + len(bufs[1])
  2814  			en += 4 + len(bufs[2])
  2815  			en += 2 + len(bufs[3])
  2816  			c.mu.Lock()
  2817  			c.out.nb = bufs
  2818  			res, n = c.collapsePtoNB()
  2819  			c.mu.Unlock()
  2820  			if test.maskingWrite {
  2821  				en += 3 * 4
  2822  			}
  2823  			if n != int64(en) {
  2824  				t.Fatalf("Expected size to be %v, got %v", en, n)
  2825  			}
  2826  			if len(res) != 7 {
  2827  				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
  2828  			}
  2829  			if len(res[4]) != wsFrameSizeForBrowsers {
  2830  				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
  2831  			}
  2832  			if test.maskingWrite {
  2833  				key := getKey(res[5])
  2834  				wsMaskBuf(key, res[6])
  2835  			}
  2836  			if string(res[6]) != "then some more" {
  2837  				t.Fatalf("Frame 6 incorrect: %q", res[6])
  2838  			}
  2839  
  2840  			bufs = nil
  2841  			c.out.pb = 0
  2842  			c.ws.fs = 0
  2843  			c.ws.frames = nil
  2844  			c.ws.browser = true
  2845  			bufs = append(bufs, []byte("some smaller "))
  2846  			bufs = append(bufs, []byte("buffers"))
  2847  			// Have one of the exact max size, and last in the list
  2848  			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers))
  2849  			en = 2 + len(bufs[0]) + len(bufs[1])
  2850  			en += 4 + len(bufs[2])
  2851  			c.mu.Lock()
  2852  			c.out.nb = bufs
  2853  			res, n = c.collapsePtoNB()
  2854  			c.mu.Unlock()
  2855  			if test.maskingWrite {
  2856  				en += 2 * 4
  2857  			}
  2858  			if n != int64(en) {
  2859  				t.Fatalf("Expected size to be %v, got %v", en, n)
  2860  			}
  2861  			if len(res) != 5 {
  2862  				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
  2863  			}
  2864  			if len(res[4]) != wsFrameSizeForBrowsers {
  2865  				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
  2866  			}
  2867  
  2868  			bufs = nil
  2869  			c.out.pb = 0
  2870  			c.ws.fs = 0
  2871  			c.ws.frames = nil
  2872  			c.ws.browser = true
  2873  			bufs = append(bufs, []byte("some smaller buffer"))
  2874  			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers-5))
  2875  			bufs = append(bufs, []byte("then some more"))
  2876  			en = 2 + len(bufs[0])
  2877  			en += 4 + len(bufs[1])
  2878  			en += 2 + len(bufs[2])
  2879  			c.mu.Lock()
  2880  			c.out.nb = bufs
  2881  			res, n = c.collapsePtoNB()
  2882  			c.mu.Unlock()
  2883  			if test.maskingWrite {
  2884  				en += 3 * 4
  2885  			}
  2886  			if n != int64(en) {
  2887  				t.Fatalf("Expected size to be %v, got %v", en, n)
  2888  			}
  2889  			if len(res) != 6 {
  2890  				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
  2891  			}
  2892  			if len(res[3]) != wsFrameSizeForBrowsers-5 {
  2893  				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
  2894  			}
  2895  			if test.maskingWrite {
  2896  				key := getKey(res[4])
  2897  				wsMaskBuf(key, res[5])
  2898  			}
  2899  			if string(res[5]) != "then some more" {
  2900  				t.Fatalf("Frame 6 incorrect %q", res[5])
  2901  			}
  2902  
  2903  			bufs = nil
  2904  			c.out.pb = 0
  2905  			c.ws.fs = 0
  2906  			c.ws.frames = nil
  2907  			c.ws.browser = true
  2908  			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+100))
  2909  			c.mu.Lock()
  2910  			c.out.nb = bufs
  2911  			res, _ = c.collapsePtoNB()
  2912  			c.mu.Unlock()
  2913  			if len(res) != 4 {
  2914  				t.Fatalf("Unexpected number of frames: %v", len(res))
  2915  			}
  2916  		})
  2917  	}
  2918  }
  2919  
  2920  func TestWSWebrowserClient(t *testing.T) {
  2921  	o := testWSOptions()
  2922  	s := RunServer(o)
  2923  	defer s.Shutdown()
  2924  
  2925  	wsc, br := testWSCreateClient(t, false, true, o.Websocket.Host, o.Websocket.Port)
  2926  	defer wsc.Close()
  2927  
  2928  	checkClientsCount(t, s, 1)
  2929  	var c *client
  2930  	s.mu.Lock()
  2931  	for _, cli := range s.clients {
  2932  		c = cli
  2933  		break
  2934  	}
  2935  	s.mu.Unlock()
  2936  
  2937  	proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB foo 1\r\nPING\r\n"))
  2938  	wsc.Write(proto)
  2939  	if res := testWSReadFrame(t, br); !bytes.Equal(res, []byte(pongProto)) {
  2940  		t.Fatalf("Expected PONG back")
  2941  	}
  2942  
  2943  	c.mu.Lock()
  2944  	ok := c.isWebsocket() && c.ws.browser == true
  2945  	c.mu.Unlock()
  2946  	if !ok {
  2947  		t.Fatalf("Client is not marked as webrowser client")
  2948  	}
  2949  
  2950  	nc := natsConnect(t, s.ClientURL())
  2951  	defer nc.Close()
  2952  
  2953  	// Send a big message and check that it is received in smaller frames
  2954  	psize := 204813
  2955  	nc.Publish("foo", make([]byte, psize))
  2956  	nc.Flush()
  2957  
  2958  	rsize := psize + len(fmt.Sprintf("MSG foo %d\r\n\r\n", psize))
  2959  	nframes := 0
  2960  	for total := 0; total < rsize; nframes++ {
  2961  		res := testWSReadFrame(t, br)
  2962  		total += len(res)
  2963  	}
  2964  	if expected := psize / wsFrameSizeForBrowsers; expected > nframes {
  2965  		t.Fatalf("Expected %v frames, got %v", expected, nframes)
  2966  	}
  2967  }
  2968  
  2969  type testWSWrappedConn struct {
  2970  	net.Conn
  2971  	mu      sync.RWMutex
  2972  	buf     *bytes.Buffer
  2973  	partial bool
  2974  }
  2975  
  2976  func (wc *testWSWrappedConn) Write(p []byte) (int, error) {
  2977  	wc.mu.Lock()
  2978  	defer wc.mu.Unlock()
  2979  	var err error
  2980  	n := len(p)
  2981  	if wc.partial && n > 10 {
  2982  		n = 10
  2983  		err = io.ErrShortWrite
  2984  	}
  2985  	p = p[:n]
  2986  	wc.buf.Write(p)
  2987  	wc.Conn.Write(p)
  2988  	return n, err
  2989  }
  2990  
  2991  func TestWSCompressionBasic(t *testing.T) {
  2992  	payload := "This is the content of a message that will be compresseddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd."
  2993  	msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
  2994  	cbuf := &bytes.Buffer{}
  2995  	compressor, err := flate.NewWriter(cbuf, flate.BestSpeed)
  2996  	require_NoError(t, err)
  2997  	compressor.Write([]byte(msgProto))
  2998  	compressor.Flush()
  2999  	compressed := cbuf.Bytes()
  3000  	// The last 4 bytes are dropped
  3001  	compressed = compressed[:len(compressed)-4]
  3002  
  3003  	o := testWSOptions()
  3004  	o.Websocket.Compression = true
  3005  	s := RunServer(o)
  3006  	defer s.Shutdown()
  3007  
  3008  	c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port)
  3009  	defer c.Close()
  3010  
  3011  	proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n"))
  3012  	c.Write(proto)
  3013  	l := testWSReadFrame(t, br)
  3014  	if !bytes.Equal(l, []byte(pongProto)) {
  3015  		t.Fatalf("Expected PONG, got %q", l)
  3016  	}
  3017  
  3018  	var wc *testWSWrappedConn
  3019  	s.mu.RLock()
  3020  	for _, c := range s.clients {
  3021  		c.mu.Lock()
  3022  		wc = &testWSWrappedConn{Conn: c.nc, buf: &bytes.Buffer{}}
  3023  		c.nc = wc
  3024  		c.mu.Unlock()
  3025  	}
  3026  	s.mu.RUnlock()
  3027  
  3028  	nc := natsConnect(t, s.ClientURL())
  3029  	defer nc.Close()
  3030  	natsPub(t, nc, "foo", []byte(payload))
  3031  
  3032  	res := &bytes.Buffer{}
  3033  	for total := 0; total < len(msgProto); {
  3034  		l := testWSReadFrame(t, br)
  3035  		n, _ := res.Write(l)
  3036  		total += n
  3037  	}
  3038  	if !bytes.Equal([]byte(msgProto), res.Bytes()) {
  3039  		t.Fatalf("Unexpected result: %q", res)
  3040  	}
  3041  
  3042  	// Now check the wrapped connection buffer to check that data was actually compressed.
  3043  	wc.mu.RLock()
  3044  	res = wc.buf
  3045  	wc.mu.RUnlock()
  3046  	if bytes.Contains(res.Bytes(), []byte(payload)) {
  3047  		t.Fatalf("Looks like frame was not compressed: %q", res.Bytes())
  3048  	}
  3049  	header := res.Bytes()[:2]
  3050  	body := res.Bytes()[2:]
  3051  	expectedB0 := byte(wsBinaryMessage) | wsFinalBit | wsRsv1Bit
  3052  	expectedPS := len(compressed)
  3053  	expectedB1 := byte(expectedPS)
  3054  
  3055  	if b := header[0]; b != expectedB0 {
  3056  		t.Fatalf("Expected first byte to be %v, got %v", expectedB0, b)
  3057  	}
  3058  	if b := header[1]; b != expectedB1 {
  3059  		t.Fatalf("Expected second byte to be %v, got %v", expectedB1, b)
  3060  	}
  3061  	if len(body) != expectedPS {
  3062  		t.Fatalf("Expected payload length to be %v, got %v", expectedPS, len(body))
  3063  	}
  3064  	if !bytes.Equal(body, compressed) {
  3065  		t.Fatalf("Unexpected compress body: %q", body)
  3066  	}
  3067  
  3068  	wc.mu.Lock()
  3069  	wc.buf.Reset()
  3070  	wc.mu.Unlock()
  3071  
  3072  	payload = "small"
  3073  	natsPub(t, nc, "foo", []byte(payload))
  3074  	msgProto = fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
  3075  	res = &bytes.Buffer{}
  3076  	for total := 0; total < len(msgProto); {
  3077  		l := testWSReadFrame(t, br)
  3078  		n, _ := res.Write(l)
  3079  		total += n
  3080  	}
  3081  	if !bytes.Equal([]byte(msgProto), res.Bytes()) {
  3082  		t.Fatalf("Unexpected result: %q", res)
  3083  	}
  3084  	wc.mu.RLock()
  3085  	res = wc.buf
  3086  	wc.mu.RUnlock()
  3087  	if !bytes.HasSuffix(res.Bytes(), []byte(msgProto)) {
  3088  		t.Fatalf("Looks like frame was compressed: %q", res.Bytes())
  3089  	}
  3090  }
  3091  
  3092  func TestWSCompressionWithPartialWrite(t *testing.T) {
  3093  	payload := "This is the content of a message that will be compresseddddddddddddddddddddd."
  3094  	msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
  3095  
  3096  	o := testWSOptions()
  3097  	o.Websocket.Compression = true
  3098  	s := RunServer(o)
  3099  	defer s.Shutdown()
  3100  
  3101  	c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port)
  3102  	defer c.Close()
  3103  
  3104  	proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n"))
  3105  	c.Write(proto)
  3106  	l := testWSReadFrame(t, br)
  3107  	if !bytes.Equal(l, []byte(pongProto)) {
  3108  		t.Fatalf("Expected PONG, got %q", l)
  3109  	}
  3110  
  3111  	pingPayload := []byte("my ping")
  3112  	pingFromWSClient := testWSCreateClientMsg(wsPingMessage, 1, true, false, pingPayload)
  3113  
  3114  	var wc *testWSWrappedConn
  3115  	var ws *client
  3116  	s.mu.Lock()
  3117  	for _, c := range s.clients {
  3118  		ws = c
  3119  		c.mu.Lock()
  3120  		wc = &testWSWrappedConn{
  3121  			Conn: c.nc,
  3122  			buf:  &bytes.Buffer{},
  3123  		}
  3124  		c.nc = wc
  3125  		c.mu.Unlock()
  3126  		break
  3127  	}
  3128  	s.mu.Unlock()
  3129  
  3130  	wc.mu.Lock()
  3131  	wc.partial = true
  3132  	wc.mu.Unlock()
  3133  
  3134  	nc := natsConnect(t, s.ClientURL())
  3135  	defer nc.Close()
  3136  
  3137  	expected := &bytes.Buffer{}
  3138  	for i := 0; i < 10; i++ {
  3139  		if i > 0 {
  3140  			time.Sleep(10 * time.Millisecond)
  3141  		}
  3142  		expected.Write([]byte(msgProto))
  3143  		natsPub(t, nc, "foo", []byte(payload))
  3144  		if i == 1 {
  3145  			c.Write(pingFromWSClient)
  3146  		}
  3147  	}
  3148  
  3149  	var gotPingResponse bool
  3150  	res := &bytes.Buffer{}
  3151  	for total := 0; total < 10*len(msgProto); {
  3152  		l := testWSReadFrame(t, br)
  3153  		if bytes.Equal(l, pingPayload) {
  3154  			gotPingResponse = true
  3155  		} else {
  3156  			n, _ := res.Write(l)
  3157  			total += n
  3158  		}
  3159  	}
  3160  	if !bytes.Equal(expected.Bytes(), res.Bytes()) {
  3161  		t.Fatalf("Unexpected result: %q", res)
  3162  	}
  3163  	if !gotPingResponse {
  3164  		t.Fatal("Did not get the ping response")
  3165  	}
  3166  
  3167  	checkFor(t, time.Second, 15*time.Millisecond, func() error {
  3168  		ws.mu.Lock()
  3169  		pb := ws.out.pb
  3170  		wf := ws.ws.frames
  3171  		fs := ws.ws.fs
  3172  		ws.mu.Unlock()
  3173  		if pb != 0 || len(wf) != 0 || fs != 0 {
  3174  			return fmt.Errorf("Expected pb, wf and fs to be 0, got %v, %v, %v", pb, wf, fs)
  3175  		}
  3176  		return nil
  3177  	})
  3178  }
  3179  
  3180  func TestWSCompressionFrameSizeLimit(t *testing.T) {
  3181  	for _, test := range []struct {
  3182  		name      string
  3183  		maskWrite bool
  3184  		noLimit   bool
  3185  	}{
  3186  		{"no write masking", false, false},
  3187  		{"write masking", true, false},
  3188  	} {
  3189  		t.Run(test.name, func(t *testing.T) {
  3190  			opts := testWSOptions()
  3191  			opts.MaxPending = MAX_PENDING_SIZE
  3192  			s := &Server{opts: opts}
  3193  			c := &client{srv: s, ws: &websocket{compress: true, browser: true, nocompfrag: test.noLimit, maskwrite: test.maskWrite}}
  3194  			c.initClient()
  3195  
  3196  			uncompressedPayload := make([]byte, 2*wsFrameSizeForBrowsers)
  3197  			for i := 0; i < len(uncompressedPayload); i++ {
  3198  				uncompressedPayload[i] = byte(rand.Intn(256))
  3199  			}
  3200  
  3201  			c.mu.Lock()
  3202  			c.out.nb = append(net.Buffers(nil), uncompressedPayload)
  3203  			nb, _ := c.collapsePtoNB()
  3204  			c.mu.Unlock()
  3205  
  3206  			if test.noLimit && len(nb) != 2 {
  3207  				t.Fatalf("There should be only 2 buffers, the header and payload, got %v", len(nb))
  3208  			}
  3209  
  3210  			bb := &bytes.Buffer{}
  3211  			var key []byte
  3212  			for i, b := range nb {
  3213  				if !test.noLimit {
  3214  					// frame header buffer are always very small. The payload should not be more
  3215  					// than 10 bytes since that is what we passed as the limit.
  3216  					if len(b) > wsFrameSizeForBrowsers {
  3217  						t.Fatalf("Frame size too big: %v (%q)", len(b), b)
  3218  					}
  3219  				}
  3220  				if test.maskWrite {
  3221  					if i%2 == 0 {
  3222  						key = b[len(b)-4:]
  3223  					} else {
  3224  						wsMaskBuf(key, b)
  3225  					}
  3226  				}
  3227  				// Check frame headers for the proper formatting.
  3228  				if i%2 == 0 {
  3229  					// Only the first frame should have the compress bit set.
  3230  					if b[0]&wsRsv1Bit != 0 {
  3231  						if i > 0 {
  3232  							t.Fatalf("Compressed bit should not be in continuation frame")
  3233  						}
  3234  					} else if i == 0 {
  3235  						t.Fatalf("Compressed bit missing")
  3236  					}
  3237  				} else {
  3238  					if test.noLimit {
  3239  						// Since the payload is likely not well compressed, we are expecting
  3240  						// the length to be > wsFrameSizeForBrowsers
  3241  						if len(b) <= wsFrameSizeForBrowsers {
  3242  							t.Fatalf("Expected frame to be bigger, got %v", len(b))
  3243  						}
  3244  					}
  3245  					// Collect the payload
  3246  					bb.Write(b)
  3247  				}
  3248  			}
  3249  			buf := bb.Bytes()
  3250  			buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
  3251  			dbr := bytes.NewBuffer(buf)
  3252  			d := flate.NewReader(dbr)
  3253  			uncompressed, err := io.ReadAll(d)
  3254  			if err != nil {
  3255  				t.Fatalf("Error reading frame: %v", err)
  3256  			}
  3257  			if !bytes.Equal(uncompressed, uncompressedPayload) {
  3258  				t.Fatalf("Unexpected uncomressed data: %q", uncompressed)
  3259  			}
  3260  		})
  3261  	}
  3262  }
  3263  
  3264  func TestWSBasicAuth(t *testing.T) {
  3265  	for _, test := range []struct {
  3266  		name    string
  3267  		opts    func() *Options
  3268  		user    string
  3269  		pass    string
  3270  		err     string
  3271  		cookies []string
  3272  	}{
  3273  		{
  3274  			"top level auth, no override, wrong u/p",
  3275  			func() *Options {
  3276  				o := testWSOptions()
  3277  				o.Username = "normal"
  3278  				o.Password = "client"
  3279  				return o
  3280  			},
  3281  			"websocket", "client", "-ERR 'Authorization Violation'",
  3282  			nil,
  3283  		},
  3284  		{
  3285  			"top level auth, no override, correct u/p",
  3286  			func() *Options {
  3287  				o := testWSOptions()
  3288  				o.Username = "normal"
  3289  				o.Password = "client"
  3290  				return o
  3291  			},
  3292  			"normal", "client", "",
  3293  			nil,
  3294  		},
  3295  		{
  3296  			"no top level auth, ws auth, wrong u/p",
  3297  			func() *Options {
  3298  				o := testWSOptions()
  3299  				o.Websocket.Username = "websocket"
  3300  				o.Websocket.Password = "client"
  3301  				return o
  3302  			},
  3303  			"normal", "client", "-ERR 'Authorization Violation'",
  3304  			nil,
  3305  		},
  3306  		{
  3307  			"no top level auth, ws auth, correct u/p",
  3308  			func() *Options {
  3309  				o := testWSOptions()
  3310  				o.Websocket.Username = "websocket"
  3311  				o.Websocket.Password = "client"
  3312  				return o
  3313  			},
  3314  			"websocket", "client", "",
  3315  			nil,
  3316  		},
  3317  		{
  3318  			"top level auth, ws override, wrong u/p",
  3319  			func() *Options {
  3320  				o := testWSOptions()
  3321  				o.Username = "normal"
  3322  				o.Password = "client"
  3323  				o.Websocket.Username = "websocket"
  3324  				o.Websocket.Password = "client"
  3325  				return o
  3326  			},
  3327  			"normal", "client", "-ERR 'Authorization Violation'",
  3328  			nil,
  3329  		},
  3330  		{
  3331  			"top level auth, ws override, correct u/p",
  3332  			func() *Options {
  3333  				o := testWSOptions()
  3334  				o.Username = "normal"
  3335  				o.Password = "client"
  3336  				o.Websocket.Username = "websocket"
  3337  				o.Websocket.Password = "client"
  3338  				return o
  3339  			},
  3340  			"websocket", "client", "",
  3341  			nil,
  3342  		},
  3343  		{
  3344  			"username/password from cookies",
  3345  			func() *Options {
  3346  				o := testWSOptions()
  3347  				o.Websocket.UsernameCookie = "un"
  3348  				o.Websocket.PasswordCookie = "pw"
  3349  				o.Username = "me"
  3350  				o.Password = "s3cr3t!"
  3351  				return o
  3352  			},
  3353  			"", "", "",
  3354  			[]string{"un=me", "pw=s3cr3t!"},
  3355  		},
  3356  		{
  3357  			"bad username/ good password from cookies",
  3358  			func() *Options {
  3359  				o := testWSOptions()
  3360  				o.Websocket.UsernameCookie = "un"
  3361  				o.Websocket.PasswordCookie = "pw"
  3362  				o.Username = "me"
  3363  				o.Password = "s3cr3t!"
  3364  				return o
  3365  			},
  3366  			"", "", "-ERR 'Authorization Violation",
  3367  			[]string{"un=m", "pw=s3cr3t!"},
  3368  		},
  3369  		{
  3370  			"good username/ bad password from cookies",
  3371  			func() *Options {
  3372  				o := testWSOptions()
  3373  				o.Websocket.UsernameCookie = "un"
  3374  				o.Websocket.PasswordCookie = "pw"
  3375  				o.Username = "me"
  3376  				o.Password = "s3cr3t!"
  3377  				return o
  3378  			},
  3379  			"", "", "-ERR 'Authorization Violation",
  3380  			[]string{"un=me", "pw=hi!"},
  3381  		},
  3382  		{
  3383  			"token from cookie",
  3384  			func() *Options {
  3385  				o := testWSOptions()
  3386  				o.Websocket.TokenCookie = "tok"
  3387  				o.Authorization = "l3tm31n!"
  3388  				return o
  3389  			},
  3390  			"", "", "",
  3391  			[]string{"tok=l3tm31n!"},
  3392  		},
  3393  		{
  3394  			"bad token from cookie",
  3395  			func() *Options {
  3396  				o := testWSOptions()
  3397  				o.Websocket.TokenCookie = "tok"
  3398  				o.Authorization = "l3tm31n!"
  3399  				return o
  3400  			},
  3401  			"", "", "-ERR 'Authorization Violation",
  3402  			[]string{"tok=hello!"},
  3403  		},
  3404  	} {
  3405  		t.Run(test.name, func(t *testing.T) {
  3406  			o := test.opts()
  3407  			s := RunServer(o)
  3408  			defer s.Shutdown()
  3409  
  3410  			wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port, test.cookies...)
  3411  			defer wsc.Close()
  3412  
  3413  			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
  3414  				test.user, test.pass)
  3415  
  3416  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3417  			if _, err := wsc.Write(wsmsg); err != nil {
  3418  				t.Fatalf("Error sending message: %v", err)
  3419  			}
  3420  			msg := testWSReadFrame(t, br)
  3421  			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3422  				t.Fatalf("Expected to receive PONG, got %q", msg)
  3423  			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
  3424  				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
  3425  			}
  3426  		})
  3427  	}
  3428  }
  3429  
  3430  func TestWSAuthTimeout(t *testing.T) {
  3431  	for _, test := range []struct {
  3432  		name string
  3433  		at   float64
  3434  		wat  float64
  3435  		err  string
  3436  	}{
  3437  		{"use top-level auth timeout", 10.0, 0.0, ""},
  3438  		{"use websocket auth timeout", 10.0, 0.05, "-ERR 'Authentication Timeout'"},
  3439  	} {
  3440  		t.Run(test.name, func(t *testing.T) {
  3441  			o := testWSOptions()
  3442  			o.AuthTimeout = test.at
  3443  			o.Websocket.Username = "websocket"
  3444  			o.Websocket.Password = "client"
  3445  			o.Websocket.AuthTimeout = test.wat
  3446  			s := RunServer(o)
  3447  			defer s.Shutdown()
  3448  
  3449  			wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
  3450  			defer wsc.Close()
  3451  
  3452  			var info serverInfo
  3453  			json.Unmarshal([]byte(l[5:]), &info)
  3454  			// Make sure that we are told that auth is required.
  3455  			if !info.AuthRequired {
  3456  				t.Fatalf("Expected auth required, was not: %q", l)
  3457  			}
  3458  			start := time.Now()
  3459  			// Wait before sending connect
  3460  			time.Sleep(100 * time.Millisecond)
  3461  			connectProto := "CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"websocket\",\"pass\":\"client\"}\r\nPING\r\n"
  3462  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3463  			if _, err := wsc.Write(wsmsg); err != nil {
  3464  				t.Fatalf("Error sending message: %v", err)
  3465  			}
  3466  			msg := testWSReadFrame(t, br)
  3467  			if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
  3468  				t.Fatalf("Expected to receive %q error, got %q", test.err, msg)
  3469  			} else if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3470  				t.Fatalf("Unexpected error: %q", msg)
  3471  			}
  3472  			if dur := time.Since(start); dur > time.Second {
  3473  				t.Fatalf("Too long to get timeout error: %v", dur)
  3474  			}
  3475  		})
  3476  	}
  3477  }
  3478  
  3479  func TestWSTokenAuth(t *testing.T) {
  3480  	for _, test := range []struct {
  3481  		name  string
  3482  		opts  func() *Options
  3483  		token string
  3484  		err   string
  3485  	}{
  3486  		{
  3487  			"top level auth, no override, wrong token",
  3488  			func() *Options {
  3489  				o := testWSOptions()
  3490  				o.Authorization = "goodtoken"
  3491  				return o
  3492  			},
  3493  			"badtoken", "-ERR 'Authorization Violation'",
  3494  		},
  3495  		{
  3496  			"top level auth, no override, correct token",
  3497  			func() *Options {
  3498  				o := testWSOptions()
  3499  				o.Authorization = "goodtoken"
  3500  				return o
  3501  			},
  3502  			"goodtoken", "",
  3503  		},
  3504  		{
  3505  			"no top level auth, ws auth, wrong token",
  3506  			func() *Options {
  3507  				o := testWSOptions()
  3508  				o.Websocket.Token = "goodtoken"
  3509  				return o
  3510  			},
  3511  			"badtoken", "-ERR 'Authorization Violation'",
  3512  		},
  3513  		{
  3514  			"no top level auth, ws auth, correct token",
  3515  			func() *Options {
  3516  				o := testWSOptions()
  3517  				o.Websocket.Token = "goodtoken"
  3518  				return o
  3519  			},
  3520  			"goodtoken", "",
  3521  		},
  3522  		{
  3523  			"top level auth, ws override, wrong token",
  3524  			func() *Options {
  3525  				o := testWSOptions()
  3526  				o.Authorization = "clienttoken"
  3527  				o.Websocket.Token = "websockettoken"
  3528  				return o
  3529  			},
  3530  			"clienttoken", "-ERR 'Authorization Violation'",
  3531  		},
  3532  		{
  3533  			"top level auth, ws override, correct token",
  3534  			func() *Options {
  3535  				o := testWSOptions()
  3536  				o.Authorization = "clienttoken"
  3537  				o.Websocket.Token = "websockettoken"
  3538  				return o
  3539  			},
  3540  			"websockettoken", "",
  3541  		},
  3542  	} {
  3543  		t.Run(test.name, func(t *testing.T) {
  3544  			o := test.opts()
  3545  			s := RunServer(o)
  3546  			defer s.Shutdown()
  3547  
  3548  			wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
  3549  			defer wsc.Close()
  3550  
  3551  			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"auth_token\":\"%s\"}\r\nPING\r\n",
  3552  				test.token)
  3553  
  3554  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3555  			if _, err := wsc.Write(wsmsg); err != nil {
  3556  				t.Fatalf("Error sending message: %v", err)
  3557  			}
  3558  			msg := testWSReadFrame(t, br)
  3559  			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3560  				t.Fatalf("Expected to receive PONG, got %q", msg)
  3561  			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
  3562  				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
  3563  			}
  3564  		})
  3565  	}
  3566  }
  3567  
  3568  func TestWSBindToProperAccount(t *testing.T) {
  3569  	conf := createConfFile(t, []byte(fmt.Sprintf(`
  3570  		listen: "127.0.0.1:-1"
  3571  		accounts {
  3572  			a {
  3573  				users [
  3574  					{user: a, password: pwd, allowed_connection_types: ["%s", "%s"]}
  3575  				]
  3576  			}
  3577  			b {
  3578  				users [
  3579  					{user: b, password: pwd}
  3580  				]
  3581  			}
  3582  		}
  3583  		websocket {
  3584  			listen: "127.0.0.1:-1"
  3585  			no_tls: true
  3586  		}
  3587  	`, jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)))) // on purpose use lower case to ensure that it is converted.
  3588  	s, o := RunServerWithConfig(conf)
  3589  	defer s.Shutdown()
  3590  
  3591  	nc := natsConnect(t, fmt.Sprintf("nats://a:pwd@127.0.0.1:%d", o.Port))
  3592  	defer nc.Close()
  3593  
  3594  	sub := natsSubSync(t, nc, "foo")
  3595  
  3596  	wsc, br, _ := testNewWSClient(t, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port, noTLS: true})
  3597  	// Send CONNECT and PING
  3598  	wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false,
  3599  		[]byte(fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", "a", "pwd")))
  3600  	if _, err := wsc.Write(wsmsg); err != nil {
  3601  		t.Fatalf("Error sending message: %v", err)
  3602  	}
  3603  	// Wait for the PONG
  3604  	if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3605  		t.Fatalf("Expected PONG, got %s", msg)
  3606  	}
  3607  
  3608  	wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n"))
  3609  	if _, err := wsc.Write(wsmsg); err != nil {
  3610  		t.Fatalf("Error sending message: %v", err)
  3611  	}
  3612  
  3613  	natsNexMsg(t, sub, time.Second)
  3614  }
  3615  
  3616  func TestWSUsersAuth(t *testing.T) {
  3617  	users := []*User{{Username: "user", Password: "pwd"}}
  3618  	for _, test := range []struct {
  3619  		name string
  3620  		opts func() *Options
  3621  		user string
  3622  		pass string
  3623  		err  string
  3624  	}{
  3625  		{
  3626  			"no filtering, wrong user",
  3627  			func() *Options {
  3628  				o := testWSOptions()
  3629  				o.Users = users
  3630  				return o
  3631  			},
  3632  			"wronguser", "pwd", "-ERR 'Authorization Violation'",
  3633  		},
  3634  		{
  3635  			"no filtering, correct user",
  3636  			func() *Options {
  3637  				o := testWSOptions()
  3638  				o.Users = users
  3639  				return o
  3640  			},
  3641  			"user", "pwd", "",
  3642  		},
  3643  		{
  3644  			"filering, user not allowed",
  3645  			func() *Options {
  3646  				o := testWSOptions()
  3647  				o.Users = users
  3648  				// Only allowed for regular clients
  3649  				o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard})
  3650  				return o
  3651  			},
  3652  			"user", "pwd", "-ERR 'Authorization Violation'",
  3653  		},
  3654  		{
  3655  			"filtering, user allowed",
  3656  			func() *Options {
  3657  				o := testWSOptions()
  3658  				o.Users = users
  3659  				o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket})
  3660  				return o
  3661  			},
  3662  			"user", "pwd", "",
  3663  		},
  3664  		{
  3665  			"filtering, wrong password",
  3666  			func() *Options {
  3667  				o := testWSOptions()
  3668  				o.Users = users
  3669  				o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket})
  3670  				return o
  3671  			},
  3672  			"user", "badpassword", "-ERR 'Authorization Violation'",
  3673  		},
  3674  	} {
  3675  		t.Run(test.name, func(t *testing.T) {
  3676  			o := test.opts()
  3677  			s := RunServer(o)
  3678  			defer s.Shutdown()
  3679  
  3680  			wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
  3681  			defer wsc.Close()
  3682  
  3683  			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
  3684  				test.user, test.pass)
  3685  
  3686  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3687  			if _, err := wsc.Write(wsmsg); err != nil {
  3688  				t.Fatalf("Error sending message: %v", err)
  3689  			}
  3690  			msg := testWSReadFrame(t, br)
  3691  			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3692  				t.Fatalf("Expected to receive PONG, got %q", msg)
  3693  			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
  3694  				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
  3695  			}
  3696  		})
  3697  	}
  3698  }
  3699  
  3700  func TestWSNoAuthUserValidation(t *testing.T) {
  3701  	o := testWSOptions()
  3702  	o.Users = []*User{{Username: "user", Password: "pwd"}}
  3703  	// Should fail because it is not part of o.Users.
  3704  	o.Websocket.NoAuthUser = "notfound"
  3705  	if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") {
  3706  		t.Fatalf("Expected error saying not present as user, got %v", err)
  3707  	}
  3708  	// Set a valid no auth user for global options, but still should fail because
  3709  	// of o.Websocket.NoAuthUser
  3710  	o.NoAuthUser = "user"
  3711  	o.Websocket.NoAuthUser = "notfound"
  3712  	if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") {
  3713  		t.Fatalf("Expected error saying not present as user, got %v", err)
  3714  	}
  3715  }
  3716  
  3717  func TestWSNoAuthUser(t *testing.T) {
  3718  	for _, test := range []struct {
  3719  		name         string
  3720  		override     bool
  3721  		useAuth      bool
  3722  		expectedUser string
  3723  		expectedAcc  string
  3724  	}{
  3725  		{"no override, no user provided", false, false, "noauth", "normal"},
  3726  		{"no override, user povided", false, true, "user", "normal"},
  3727  		{"override, no user provided", true, false, "wsnoauth", "websocket"},
  3728  		{"override, user provided", true, true, "wsuser", "websocket"},
  3729  	} {
  3730  		t.Run(test.name, func(t *testing.T) {
  3731  			o := testWSOptions()
  3732  			normalAcc := NewAccount("normal")
  3733  			websocketAcc := NewAccount("websocket")
  3734  			o.Accounts = []*Account{normalAcc, websocketAcc}
  3735  			o.Users = []*User{
  3736  				{Username: "noauth", Password: "pwd", Account: normalAcc},
  3737  				{Username: "user", Password: "pwd", Account: normalAcc},
  3738  				{Username: "wsnoauth", Password: "pwd", Account: websocketAcc},
  3739  				{Username: "wsuser", Password: "pwd", Account: websocketAcc},
  3740  			}
  3741  			o.NoAuthUser = "noauth"
  3742  			if test.override {
  3743  				o.Websocket.NoAuthUser = "wsnoauth"
  3744  			}
  3745  			s := RunServer(o)
  3746  			defer s.Shutdown()
  3747  
  3748  			wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
  3749  			defer wsc.Close()
  3750  
  3751  			var info serverInfo
  3752  			json.Unmarshal([]byte(l[5:]), &info)
  3753  
  3754  			var connectProto string
  3755  			if test.useAuth {
  3756  				connectProto = fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"pwd\"}\r\nPING\r\n",
  3757  					test.expectedUser)
  3758  			} else {
  3759  				connectProto = "CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"
  3760  			}
  3761  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3762  			if _, err := wsc.Write(wsmsg); err != nil {
  3763  				t.Fatalf("Error sending message: %v", err)
  3764  			}
  3765  			msg := testWSReadFrame(t, br)
  3766  			if !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3767  				t.Fatalf("Unexpected error: %q", msg)
  3768  			}
  3769  
  3770  			c := s.getClient(info.CID)
  3771  			c.mu.Lock()
  3772  			uname := c.opts.Username
  3773  			aname := c.acc.GetName()
  3774  			c.mu.Unlock()
  3775  			if uname != test.expectedUser {
  3776  				t.Fatalf("Expected selected user to be %q, got %q", test.expectedUser, uname)
  3777  			}
  3778  			if aname != test.expectedAcc {
  3779  				t.Fatalf("Expected selected account to be %q, got %q", test.expectedAcc, aname)
  3780  			}
  3781  		})
  3782  	}
  3783  }
  3784  
  3785  func TestWSNkeyAuth(t *testing.T) {
  3786  	nkp, _ := nkeys.CreateUser()
  3787  	pub, _ := nkp.PublicKey()
  3788  
  3789  	wsnkp, _ := nkeys.CreateUser()
  3790  	wspub, _ := wsnkp.PublicKey()
  3791  
  3792  	badkp, _ := nkeys.CreateUser()
  3793  	badpub, _ := badkp.PublicKey()
  3794  
  3795  	for _, test := range []struct {
  3796  		name string
  3797  		opts func() *Options
  3798  		nkey string
  3799  		kp   nkeys.KeyPair
  3800  		err  string
  3801  	}{
  3802  		{
  3803  			"no filtering, wrong nkey",
  3804  			func() *Options {
  3805  				o := testWSOptions()
  3806  				o.Nkeys = []*NkeyUser{{Nkey: pub}}
  3807  				return o
  3808  			},
  3809  			badpub, badkp, "-ERR 'Authorization Violation'",
  3810  		},
  3811  		{
  3812  			"no filtering, correct nkey",
  3813  			func() *Options {
  3814  				o := testWSOptions()
  3815  				o.Nkeys = []*NkeyUser{{Nkey: pub}}
  3816  				return o
  3817  			},
  3818  			pub, nkp, "",
  3819  		},
  3820  		{
  3821  			"filtering, nkey not allowed",
  3822  			func() *Options {
  3823  				o := testWSOptions()
  3824  				o.Nkeys = []*NkeyUser{
  3825  					{
  3826  						Nkey:                   pub,
  3827  						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard}),
  3828  					},
  3829  					{
  3830  						Nkey:                   wspub,
  3831  						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeWebsocket}),
  3832  					},
  3833  				}
  3834  				return o
  3835  			},
  3836  			pub, nkp, "-ERR 'Authorization Violation'",
  3837  		},
  3838  		{
  3839  			"filtering, correct nkey",
  3840  			func() *Options {
  3841  				o := testWSOptions()
  3842  				o.Nkeys = []*NkeyUser{
  3843  					{Nkey: pub},
  3844  					{
  3845  						Nkey:                   wspub,
  3846  						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}),
  3847  					},
  3848  				}
  3849  				return o
  3850  			},
  3851  			wspub, wsnkp, "",
  3852  		},
  3853  		{
  3854  			"filtering, wrong nkey",
  3855  			func() *Options {
  3856  				o := testWSOptions()
  3857  				o.Nkeys = []*NkeyUser{
  3858  					{
  3859  						Nkey:                   wspub,
  3860  						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}),
  3861  					},
  3862  				}
  3863  				return o
  3864  			},
  3865  			badpub, badkp, "-ERR 'Authorization Violation'",
  3866  		},
  3867  	} {
  3868  		t.Run(test.name, func(t *testing.T) {
  3869  			o := test.opts()
  3870  			s := RunServer(o)
  3871  			defer s.Shutdown()
  3872  
  3873  			wsc, br, infoMsg := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
  3874  			defer wsc.Close()
  3875  
  3876  			// Sign Nonce
  3877  			var info nonceInfo
  3878  			json.Unmarshal([]byte(infoMsg[5:]), &info)
  3879  			sigraw, _ := test.kp.Sign([]byte(info.Nonce))
  3880  			sig := base64.RawURLEncoding.EncodeToString(sigraw)
  3881  
  3882  			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"nkey\":\"%s\",\"sig\":\"%s\"}\r\nPING\r\n", test.nkey, sig)
  3883  
  3884  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3885  			if _, err := wsc.Write(wsmsg); err != nil {
  3886  				t.Fatalf("Error sending message: %v", err)
  3887  			}
  3888  			msg := testWSReadFrame(t, br)
  3889  			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3890  				t.Fatalf("Expected to receive PONG, got %q", msg)
  3891  			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
  3892  				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
  3893  			}
  3894  		})
  3895  	}
  3896  }
  3897  
  3898  func TestWSSetHeaderServer(t *testing.T) {
  3899  	o := testWSOptions()
  3900  	o.Websocket.Headers = map[string]string{
  3901  		"X-Custom-Header": "custom-value",
  3902  	}
  3903  
  3904  	s := RunServer(o)
  3905  	defer s.Shutdown()
  3906  
  3907  	opts := testWSClientOptions{
  3908  		host:                 o.Websocket.Host,
  3909  		port:                 o.Websocket.Port,
  3910  		extraResponseHeaders: o.Websocket.Headers,
  3911  	}
  3912  
  3913  	c, _, _ := testNewWSClient(t, opts)
  3914  	defer c.Close()
  3915  }
  3916  
  3917  func TestWSJWTWithAllowedConnectionTypes(t *testing.T) {
  3918  	o := testWSOptions()
  3919  	setupAddTrusted(o)
  3920  	s := RunServer(o)
  3921  	buildMemAccResolver(s)
  3922  	defer s.Shutdown()
  3923  
  3924  	for _, test := range []struct {
  3925  		name            string
  3926  		connectionTypes []string
  3927  		expectedAnswer  string
  3928  	}{
  3929  		{"not allowed", []string{jwt.ConnectionTypeStandard}, "-ERR"},
  3930  		{"allowed", []string{jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)}, "+OK"},
  3931  		{"allowed with unknown", []string{jwt.ConnectionTypeWebsocket, "SomeNewType"}, "+OK"},
  3932  		{"not allowed with unknown", []string{"SomeNewType"}, "-ERR"},
  3933  	} {
  3934  		t.Run(test.name, func(t *testing.T) {
  3935  			nuc := newJWTTestUserClaims()
  3936  			nuc.AllowedConnectionTypes = test.connectionTypes
  3937  			claimOpt := testClaimsOptions{
  3938  				nuc:          nuc,
  3939  				expectAnswer: test.expectedAnswer,
  3940  			}
  3941  			_, c, _, _ := testWSWithClaims(t, s, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port}, claimOpt)
  3942  			c.Close()
  3943  		})
  3944  	}
  3945  }
  3946  
  3947  func TestWSJWTCookieUser(t *testing.T) {
  3948  	nucSigFunc := func() *jwt.UserClaims { return newJWTTestUserClaims() }
  3949  	nucBearerFunc := func() *jwt.UserClaims {
  3950  		ret := newJWTTestUserClaims()
  3951  		ret.BearerToken = true
  3952  		return ret
  3953  	}
  3954  
  3955  	o := testWSOptions()
  3956  	setupAddTrusted(o)
  3957  	setupAddCookie(o)
  3958  	s := RunServer(o)
  3959  	buildMemAccResolver(s)
  3960  	defer s.Shutdown()
  3961  
  3962  	genJwt := func(t *testing.T, nuc *jwt.UserClaims) string {
  3963  		okp, _ := nkeys.FromSeed(oSeed)
  3964  
  3965  		akp, _ := nkeys.CreateAccount()
  3966  		apub, _ := akp.PublicKey()
  3967  
  3968  		nac := jwt.NewAccountClaims(apub)
  3969  		ajwt, err := nac.Encode(okp)
  3970  		if err != nil {
  3971  			t.Fatalf("Error generating account JWT: %v", err)
  3972  		}
  3973  
  3974  		nkp, _ := nkeys.CreateUser()
  3975  		pub, _ := nkp.PublicKey()
  3976  		nuc.Subject = pub
  3977  		jwt, err := nuc.Encode(akp)
  3978  		if err != nil {
  3979  			t.Fatalf("Error generating user JWT: %v", err)
  3980  		}
  3981  		addAccountToMemResolver(s, apub, ajwt)
  3982  		return jwt
  3983  	}
  3984  
  3985  	cliOpts := testWSClientOptions{
  3986  		host: o.Websocket.Host,
  3987  		port: o.Websocket.Port,
  3988  	}
  3989  	for _, test := range []struct {
  3990  		name         string
  3991  		nuc          *jwt.UserClaims
  3992  		opts         func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions)
  3993  		expectAnswer string
  3994  	}{
  3995  		{
  3996  			name: "protocol auth, non-bearer key, with signature",
  3997  			nuc:  nucSigFunc(),
  3998  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  3999  				return cliOpts, testClaimsOptions{nuc: claims}
  4000  			},
  4001  			expectAnswer: "+OK",
  4002  		},
  4003  		{
  4004  			name: "protocol auth, non-bearer key, w/o required signature",
  4005  			nuc:  nucSigFunc(),
  4006  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  4007  				return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
  4008  			},
  4009  			expectAnswer: "-ERR",
  4010  		},
  4011  		{
  4012  			name: "protocol auth, bearer key, w/o signature",
  4013  			nuc:  nucBearerFunc(),
  4014  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  4015  				return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
  4016  			},
  4017  			expectAnswer: "+OK",
  4018  		},
  4019  		{
  4020  			name: "cookie auth, non-bearer key, protocol auth fail",
  4021  			nuc:  nucSigFunc(),
  4022  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  4023  				co := cliOpts
  4024  				co.extraHeaders = map[string][]string{}
  4025  				co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)}
  4026  				return co, testClaimsOptions{connectRequest: struct{}{}}
  4027  			},
  4028  			expectAnswer: "-ERR",
  4029  		},
  4030  		{
  4031  			name: "cookie auth, bearer key, protocol auth success with implied cookie jwt",
  4032  			nuc:  nucBearerFunc(),
  4033  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  4034  				co := cliOpts
  4035  				co.extraHeaders = map[string][]string{}
  4036  				co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)}
  4037  				return co, testClaimsOptions{connectRequest: struct{}{}}
  4038  			},
  4039  			expectAnswer: "+OK",
  4040  		},
  4041  		{
  4042  			name: "cookie auth, non-bearer key, protocol auth success via override jwt in CONNECT opts",
  4043  			nuc:  nucSigFunc(),
  4044  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  4045  				co := cliOpts
  4046  				co.extraHeaders = map[string][]string{}
  4047  				co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)}
  4048  				return co, testClaimsOptions{nuc: nucBearerFunc()}
  4049  			},
  4050  			expectAnswer: "+OK",
  4051  		},
  4052  	} {
  4053  		t.Run(test.name, func(t *testing.T) {
  4054  			cliOpt, claimOpt := test.opts(t, test.nuc)
  4055  			claimOpt.expectAnswer = test.expectAnswer
  4056  			_, c, _, _ := testWSWithClaims(t, s, cliOpt, claimOpt)
  4057  			c.Close()
  4058  		})
  4059  	}
  4060  	s.Shutdown()
  4061  }
  4062  
  4063  func TestWSReloadTLSConfig(t *testing.T) {
  4064  	template := `
  4065  		listen: "127.0.0.1:-1"
  4066  		websocket {
  4067  			listen: "127.0.0.1:-1"
  4068  			tls {
  4069  				cert_file: '%s'
  4070  				key_file: '%s'
  4071  				ca_file: '../test/configs/certs/ca.pem'
  4072  			}
  4073  		}
  4074  	`
  4075  	conf := createConfFile(t, []byte(fmt.Sprintf(template,
  4076  		"../test/configs/certs/server-noip.pem",
  4077  		"../test/configs/certs/server-key-noip.pem")))
  4078  
  4079  	s, o := RunServerWithConfig(conf)
  4080  	defer s.Shutdown()
  4081  
  4082  	addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port)
  4083  	wsc, err := net.Dial("tcp", addr)
  4084  	if err != nil {
  4085  		t.Fatalf("Error creating ws connection: %v", err)
  4086  	}
  4087  	defer wsc.Close()
  4088  
  4089  	tc := &TLSConfigOpts{CaFile: "../test/configs/certs/ca.pem"}
  4090  	tlsConfig, err := GenTLSConfig(tc)
  4091  	if err != nil {
  4092  		t.Fatalf("Error generating TLS config: %v", err)
  4093  	}
  4094  	tlsConfig.ServerName = "127.0.0.1"
  4095  	tlsConfig.RootCAs = tlsConfig.ClientCAs
  4096  	tlsConfig.ClientCAs = nil
  4097  	wsc = tls.Client(wsc, tlsConfig.Clone())
  4098  	if err := wsc.(*tls.Conn).Handshake(); err == nil || !strings.Contains(err.Error(), "SAN") {
  4099  		t.Fatalf("Unexpected error: %v", err)
  4100  	}
  4101  	wsc.Close()
  4102  
  4103  	reloadUpdateConfig(t, s, conf, fmt.Sprintf(template,
  4104  		"../test/configs/certs/server-cert.pem",
  4105  		"../test/configs/certs/server-key.pem"))
  4106  
  4107  	wsc, err = net.Dial("tcp", addr)
  4108  	if err != nil {
  4109  		t.Fatalf("Error creating ws connection: %v", err)
  4110  	}
  4111  	defer wsc.Close()
  4112  
  4113  	wsc = tls.Client(wsc, tlsConfig.Clone())
  4114  	if err := wsc.(*tls.Conn).Handshake(); err != nil {
  4115  		t.Fatalf("Error on TLS handshake: %v", err)
  4116  	}
  4117  }
  4118  
  4119  type captureClientConnectedLogger struct {
  4120  	DummyLogger
  4121  	ch chan string
  4122  }
  4123  
  4124  func (l *captureClientConnectedLogger) Debugf(format string, v ...any) {
  4125  	msg := fmt.Sprintf(format, v...)
  4126  	if !strings.Contains(msg, "Client connection created") {
  4127  		return
  4128  	}
  4129  	select {
  4130  	case l.ch <- msg:
  4131  	default:
  4132  	}
  4133  }
  4134  
  4135  func TestWSXForwardedFor(t *testing.T) {
  4136  	o := testWSOptions()
  4137  	s := RunServer(o)
  4138  	defer s.Shutdown()
  4139  
  4140  	l := &captureClientConnectedLogger{ch: make(chan string, 1)}
  4141  	s.SetLogger(l, true, false)
  4142  
  4143  	for _, test := range []struct {
  4144  		name          string
  4145  		headers       func() map[string][]string
  4146  		useHdrValue   bool
  4147  		expectedValue string
  4148  	}{
  4149  		{"nil map", func() map[string][]string {
  4150  			return nil
  4151  		}, false, _EMPTY_},
  4152  		{"empty map", func() map[string][]string {
  4153  			return make(map[string][]string)
  4154  		}, false, _EMPTY_},
  4155  		{"header present empty value", func() map[string][]string {
  4156  			m := make(map[string][]string)
  4157  			m[wsXForwardedForHeader] = []string{}
  4158  			return m
  4159  		}, false, _EMPTY_},
  4160  		{"header present invalid IP", func() map[string][]string {
  4161  			m := make(map[string][]string)
  4162  			m[wsXForwardedForHeader] = []string{"not a valid IP"}
  4163  			return m
  4164  		}, false, _EMPTY_},
  4165  		{"header present one IP", func() map[string][]string {
  4166  			m := make(map[string][]string)
  4167  			m[wsXForwardedForHeader] = []string{"1.2.3.4"}
  4168  			return m
  4169  		}, true, "1.2.3.4"},
  4170  		{"header present multiple IPs", func() map[string][]string {
  4171  			m := make(map[string][]string)
  4172  			m[wsXForwardedForHeader] = []string{"1.2.3.4", "5.6.7.8"}
  4173  			return m
  4174  		}, true, "1.2.3.4"},
  4175  		{"header present IPv6", func() map[string][]string {
  4176  			m := make(map[string][]string)
  4177  			m[wsXForwardedForHeader] = []string{"::1"}
  4178  			return m
  4179  		}, true, "[::1]"},
  4180  	} {
  4181  		t.Run(test.name, func(t *testing.T) {
  4182  			c, r, _ := testNewWSClient(t, testWSClientOptions{
  4183  				host:         o.Websocket.Host,
  4184  				port:         o.Websocket.Port,
  4185  				extraHeaders: test.headers(),
  4186  			})
  4187  			defer c.Close()
  4188  			// Send CONNECT and PING
  4189  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
  4190  			if _, err := c.Write(wsmsg); err != nil {
  4191  				t.Fatalf("Error sending message: %v", err)
  4192  			}
  4193  			// Wait for the PONG
  4194  			if msg := testWSReadFrame(t, r); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  4195  				t.Fatalf("Expected PONG, got %s", msg)
  4196  			}
  4197  			select {
  4198  			case d := <-l.ch:
  4199  				ipAndSlash := fmt.Sprintf("%s/", test.expectedValue)
  4200  				if test.useHdrValue {
  4201  					if !strings.HasPrefix(d, ipAndSlash) {
  4202  						t.Fatalf("Expected debug statement to start with: %q, got %q", ipAndSlash, d)
  4203  					}
  4204  				} else if strings.HasPrefix(d, ipAndSlash) {
  4205  					t.Fatalf("Unexpected debug statement: %q", d)
  4206  				}
  4207  			case <-time.After(time.Second):
  4208  				t.Fatal("Did not get connect debug statement")
  4209  			}
  4210  		})
  4211  	}
  4212  }
  4213  
  4214  type partialWriteConn struct {
  4215  	net.Conn
  4216  }
  4217  
  4218  func (c *partialWriteConn) Write(b []byte) (int, error) {
  4219  	max := len(b)
  4220  	if max > 0 {
  4221  		max = rand.Intn(max)
  4222  		if max == 0 {
  4223  			max = 1
  4224  		}
  4225  	}
  4226  	n, err := c.Conn.Write(b[:max])
  4227  	if err == nil && max != len(b) {
  4228  		err = io.ErrShortWrite
  4229  	}
  4230  	return n, err
  4231  }
  4232  
  4233  func TestWSWithPartialWrite(t *testing.T) {
  4234  	conf := createConfFile(t, []byte(`
  4235  		listen: "127.0.0.1:-1"
  4236  		websocket {
  4237  			listen: "127.0.0.1:-1"
  4238  			no_tls: true
  4239  		}
  4240  	`))
  4241  	s, o := RunServerWithConfig(conf)
  4242  	defer s.Shutdown()
  4243  
  4244  	nc1 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o.Websocket.Port))
  4245  	defer nc1.Close()
  4246  
  4247  	sub := natsSubSync(t, nc1, "foo")
  4248  	sub.SetPendingLimits(-1, -1)
  4249  	natsFlush(t, nc1)
  4250  
  4251  	nc2 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o.Websocket.Port))
  4252  	defer nc2.Close()
  4253  
  4254  	// Replace websocket connections with ones that will produce short writes.
  4255  	s.mu.RLock()
  4256  	for _, c := range s.clients {
  4257  		c.mu.Lock()
  4258  		c.nc = &partialWriteConn{Conn: c.nc}
  4259  		c.mu.Unlock()
  4260  	}
  4261  	s.mu.RUnlock()
  4262  
  4263  	var msgs [][]byte
  4264  	for i := 0; i < 100; i++ {
  4265  		msg := make([]byte, rand.Intn(10000)+10)
  4266  		for j := 0; j < len(msg); j++ {
  4267  			msg[j] = byte('A' + j%26)
  4268  		}
  4269  		msgs = append(msgs, msg)
  4270  		natsPub(t, nc2, "foo", msg)
  4271  	}
  4272  	for i := 0; i < 100; i++ {
  4273  		rmsg := natsNexMsg(t, sub, time.Second)
  4274  		if !bytes.Equal(msgs[i], rmsg.Data) {
  4275  			t.Fatalf("Expected message %q, got %q", msgs[i], rmsg.Data)
  4276  		}
  4277  	}
  4278  }
  4279  
  4280  func testWSNoCorruptionWithFrameSizeLimit(t *testing.T, total int) {
  4281  	tmpl := `
  4282                 listen: "127.0.0.1:-1"
  4283                 cluster {
  4284                         name: "local"
  4285                         port: -1
  4286                         %s
  4287                 }
  4288                 websocket {
  4289                         listen: "127.0.0.1:-1"
  4290                         no_tls: true
  4291                 }
  4292         `
  4293  	conf1 := createConfFile(t, []byte(fmt.Sprintf(tmpl, _EMPTY_)))
  4294  	s1, o1 := RunServerWithConfig(conf1)
  4295  	defer s1.Shutdown()
  4296  
  4297  	routes := fmt.Sprintf("routes: [\"nats://127.0.0.1:%d\"]", o1.Cluster.Port)
  4298  	conf2 := createConfFile(t, []byte(fmt.Sprintf(tmpl, routes)))
  4299  	s2, o2 := RunServerWithConfig(conf2)
  4300  	defer s2.Shutdown()
  4301  
  4302  	conf3 := createConfFile(t, []byte(fmt.Sprintf(tmpl, routes)))
  4303  	s3, o3 := RunServerWithConfig(conf3)
  4304  	defer s3.Shutdown()
  4305  
  4306  	checkClusterFormed(t, s1, s2, s3)
  4307  
  4308  	nc3 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o3.Websocket.Port))
  4309  	defer nc3.Close()
  4310  
  4311  	nc2 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o2.Websocket.Port))
  4312  	defer nc2.Close()
  4313  
  4314  	payload := make([]byte, 100000)
  4315  	for i := 0; i < len(payload); i++ {
  4316  		payload[i] = 'A' + byte(i%26)
  4317  	}
  4318  	errCh := make(chan error, 1)
  4319  	doneCh := make(chan struct{}, 1)
  4320  	count := int32(0)
  4321  
  4322  	createSub := func(nc *nats.Conn) {
  4323  		sub := natsSub(t, nc, "foo", func(m *nats.Msg) {
  4324  			if !bytes.Equal(m.Data, payload) {
  4325  				stop := len(m.Data)
  4326  				if l := len(payload); l < stop {
  4327  					stop = l
  4328  				}
  4329  				start := 0
  4330  				for i := 0; i < stop; i++ {
  4331  					if m.Data[i] != payload[i] {
  4332  						start = i
  4333  						break
  4334  					}
  4335  				}
  4336  				if stop-start > 20 {
  4337  					stop = start + 20
  4338  				}
  4339  				select {
  4340  				case errCh <- fmt.Errorf("Invalid message: [%d bytes same]%s[...]", start, m.Data[start:stop]):
  4341  				default:
  4342  				}
  4343  				return
  4344  			}
  4345  			if n := atomic.AddInt32(&count, 1); int(n) == 2*total {
  4346  				doneCh <- struct{}{}
  4347  			}
  4348  		})
  4349  		sub.SetPendingLimits(-1, -1)
  4350  	}
  4351  	createSub(nc2)
  4352  	createSub(nc3)
  4353  
  4354  	checkSubInterest(t, s1, globalAccountName, "foo", time.Second)
  4355  
  4356  	nc1 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o1.Websocket.Port))
  4357  	defer nc1.Close()
  4358  	natsFlush(t, nc1)
  4359  
  4360  	// Change websocket connections to force a max frame size.
  4361  	for _, s := range []*Server{s1, s2, s3} {
  4362  		s.mu.RLock()
  4363  		for _, c := range s.clients {
  4364  			c.mu.Lock()
  4365  			if c.ws != nil {
  4366  				c.ws.browser = true
  4367  			}
  4368  			c.mu.Unlock()
  4369  		}
  4370  		s.mu.RUnlock()
  4371  	}
  4372  
  4373  	for i := 0; i < total; i++ {
  4374  		natsPub(t, nc1, "foo", payload)
  4375  		if i%100 == 0 {
  4376  			select {
  4377  			case err := <-errCh:
  4378  				t.Fatalf("Error: %v", err)
  4379  			default:
  4380  			}
  4381  		}
  4382  	}
  4383  	select {
  4384  	case err := <-errCh:
  4385  		t.Fatalf("Error: %v", err)
  4386  	case <-doneCh:
  4387  		return
  4388  	case <-time.After(10 * time.Second):
  4389  		t.Fatalf("Test timed out")
  4390  	}
  4391  }
  4392  
  4393  func TestWSNoCorruptionWithFrameSizeLimit(t *testing.T) {
  4394  	testWSNoCorruptionWithFrameSizeLimit(t, 1000)
  4395  }
  4396  
  4397  // ==================================================================
  4398  // = Benchmark tests
  4399  // ==================================================================
  4400  
  4401  const testWSBenchSubject = "a"
  4402  
  4403  var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()")
  4404  
  4405  func sizedString(sz int) string {
  4406  	b := make([]byte, sz)
  4407  	for i := range b {
  4408  		b[i] = ch[rand.Intn(len(ch))]
  4409  	}
  4410  	return string(b)
  4411  }
  4412  
  4413  func sizedStringForCompression(sz int) string {
  4414  	b := make([]byte, sz)
  4415  	c := byte(0)
  4416  	s := 0
  4417  	for i := range b {
  4418  		if s%20 == 0 {
  4419  			c = ch[rand.Intn(len(ch))]
  4420  		}
  4421  		b[i] = c
  4422  	}
  4423  	return string(b)
  4424  }
  4425  
  4426  func testWSFlushConn(b *testing.B, compress bool, c net.Conn, br *bufio.Reader) {
  4427  	buf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte(pingProto))
  4428  	c.Write(buf)
  4429  	c.SetReadDeadline(time.Now().Add(5 * time.Second))
  4430  	res := testWSReadFrame(b, br)
  4431  	c.SetReadDeadline(time.Time{})
  4432  	if !bytes.HasPrefix(res, []byte(pongProto)) {
  4433  		b.Fatalf("Failed read of PONG: %s\n", res)
  4434  	}
  4435  }
  4436  
  4437  func wsBenchPub(b *testing.B, numPubs int, compress bool, payload string) {
  4438  	b.StopTimer()
  4439  	opts := testWSOptions()
  4440  	opts.Websocket.Compression = compress
  4441  	s := RunServer(opts)
  4442  	defer s.Shutdown()
  4443  
  4444  	extra := 0
  4445  	pubProto := []byte(fmt.Sprintf("PUB %s %d\r\n%s\r\n", testWSBenchSubject, len(payload), payload))
  4446  	singleOpBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, pubProto)
  4447  
  4448  	// Simulate client that would buffer messages before framing/sending.
  4449  	// Figure out how many we can fit in one frame based on b.N and length of pubProto
  4450  	const bufSize = 32768
  4451  	tmpa := [bufSize]byte{}
  4452  	tmp := tmpa[:0]
  4453  	pb := 0
  4454  	for i := 0; i < b.N; i++ {
  4455  		tmp = append(tmp, pubProto...)
  4456  		pb++
  4457  		if len(tmp) >= bufSize {
  4458  			break
  4459  		}
  4460  	}
  4461  	sendBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, tmp)
  4462  	n := b.N / pb
  4463  	extra = b.N - (n * pb)
  4464  
  4465  	wg := sync.WaitGroup{}
  4466  	wg.Add(numPubs)
  4467  
  4468  	type pub struct {
  4469  		c  net.Conn
  4470  		br *bufio.Reader
  4471  		bw *bufio.Writer
  4472  	}
  4473  	var pubs []pub
  4474  	for i := 0; i < numPubs; i++ {
  4475  		wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port)
  4476  		defer wsc.Close()
  4477  		bw := bufio.NewWriterSize(wsc, bufSize)
  4478  		pubs = append(pubs, pub{wsc, br, bw})
  4479  	}
  4480  
  4481  	// Average the amount of bytes sent by iteration
  4482  	avg := len(sendBuf) / pb
  4483  	if extra > 0 {
  4484  		avg += len(singleOpBuf)
  4485  		avg /= 2
  4486  	}
  4487  	b.SetBytes(int64(numPubs * avg))
  4488  	b.StartTimer()
  4489  
  4490  	for i := 0; i < numPubs; i++ {
  4491  		p := pubs[i]
  4492  		go func(p pub) {
  4493  			defer wg.Done()
  4494  			for i := 0; i < n; i++ {
  4495  				p.bw.Write(sendBuf)
  4496  			}
  4497  			for i := 0; i < extra; i++ {
  4498  				p.bw.Write(singleOpBuf)
  4499  			}
  4500  			p.bw.Flush()
  4501  			testWSFlushConn(b, compress, p.c, p.br)
  4502  		}(p)
  4503  	}
  4504  	wg.Wait()
  4505  	b.StopTimer()
  4506  }
  4507  
  4508  func Benchmark_WS_Pubx1_CN_____0b(b *testing.B) {
  4509  	wsBenchPub(b, 1, false, "")
  4510  }
  4511  
  4512  func Benchmark_WS_Pubx1_CY_____0b(b *testing.B) {
  4513  	wsBenchPub(b, 1, true, "")
  4514  }
  4515  
  4516  func Benchmark_WS_Pubx1_CN___128b(b *testing.B) {
  4517  	s := sizedString(128)
  4518  	wsBenchPub(b, 1, false, s)
  4519  }
  4520  
  4521  func Benchmark_WS_Pubx1_CY___128b(b *testing.B) {
  4522  	s := sizedStringForCompression(128)
  4523  	wsBenchPub(b, 1, true, s)
  4524  }
  4525  
  4526  func Benchmark_WS_Pubx1_CN__1024b(b *testing.B) {
  4527  	s := sizedString(1024)
  4528  	wsBenchPub(b, 1, false, s)
  4529  }
  4530  
  4531  func Benchmark_WS_Pubx1_CY__1024b(b *testing.B) {
  4532  	s := sizedStringForCompression(1024)
  4533  	wsBenchPub(b, 1, true, s)
  4534  }
  4535  
  4536  func Benchmark_WS_Pubx1_CN__4096b(b *testing.B) {
  4537  	s := sizedString(4 * 1024)
  4538  	wsBenchPub(b, 1, false, s)
  4539  }
  4540  
  4541  func Benchmark_WS_Pubx1_CY__4096b(b *testing.B) {
  4542  	s := sizedStringForCompression(4 * 1024)
  4543  	wsBenchPub(b, 1, true, s)
  4544  }
  4545  
  4546  func Benchmark_WS_Pubx1_CN__8192b(b *testing.B) {
  4547  	s := sizedString(8 * 1024)
  4548  	wsBenchPub(b, 1, false, s)
  4549  }
  4550  
  4551  func Benchmark_WS_Pubx1_CY__8192b(b *testing.B) {
  4552  	s := sizedStringForCompression(8 * 1024)
  4553  	wsBenchPub(b, 1, true, s)
  4554  }
  4555  
  4556  func Benchmark_WS_Pubx1_CN_32768b(b *testing.B) {
  4557  	s := sizedString(32 * 1024)
  4558  	wsBenchPub(b, 1, false, s)
  4559  }
  4560  
  4561  func Benchmark_WS_Pubx1_CY_32768b(b *testing.B) {
  4562  	s := sizedStringForCompression(32 * 1024)
  4563  	wsBenchPub(b, 1, true, s)
  4564  }
  4565  
  4566  func Benchmark_WS_Pubx5_CN_____0b(b *testing.B) {
  4567  	wsBenchPub(b, 5, false, "")
  4568  }
  4569  
  4570  func Benchmark_WS_Pubx5_CY_____0b(b *testing.B) {
  4571  	wsBenchPub(b, 5, true, "")
  4572  }
  4573  
  4574  func Benchmark_WS_Pubx5_CN___128b(b *testing.B) {
  4575  	s := sizedString(128)
  4576  	wsBenchPub(b, 5, false, s)
  4577  }
  4578  
  4579  func Benchmark_WS_Pubx5_CY___128b(b *testing.B) {
  4580  	s := sizedStringForCompression(128)
  4581  	wsBenchPub(b, 5, true, s)
  4582  }
  4583  
  4584  func Benchmark_WS_Pubx5_CN__1024b(b *testing.B) {
  4585  	s := sizedString(1024)
  4586  	wsBenchPub(b, 5, false, s)
  4587  }
  4588  
  4589  func Benchmark_WS_Pubx5_CY__1024b(b *testing.B) {
  4590  	s := sizedStringForCompression(1024)
  4591  	wsBenchPub(b, 5, true, s)
  4592  }
  4593  
  4594  func Benchmark_WS_Pubx5_CN__4096b(b *testing.B) {
  4595  	s := sizedString(4 * 1024)
  4596  	wsBenchPub(b, 5, false, s)
  4597  }
  4598  
  4599  func Benchmark_WS_Pubx5_CY__4096b(b *testing.B) {
  4600  	s := sizedStringForCompression(4 * 1024)
  4601  	wsBenchPub(b, 5, true, s)
  4602  }
  4603  
  4604  func Benchmark_WS_Pubx5_CN__8192b(b *testing.B) {
  4605  	s := sizedString(8 * 1024)
  4606  	wsBenchPub(b, 5, false, s)
  4607  }
  4608  
  4609  func Benchmark_WS_Pubx5_CY__8192b(b *testing.B) {
  4610  	s := sizedStringForCompression(8 * 1024)
  4611  	wsBenchPub(b, 5, true, s)
  4612  }
  4613  
  4614  func Benchmark_WS_Pubx5_CN_32768b(b *testing.B) {
  4615  	s := sizedString(32 * 1024)
  4616  	wsBenchPub(b, 5, false, s)
  4617  }
  4618  
  4619  func Benchmark_WS_Pubx5_CY_32768b(b *testing.B) {
  4620  	s := sizedStringForCompression(32 * 1024)
  4621  	wsBenchPub(b, 5, true, s)
  4622  }
  4623  
  4624  func wsBenchSub(b *testing.B, numSubs int, compress bool, payload string) {
  4625  	b.StopTimer()
  4626  	opts := testWSOptions()
  4627  	opts.Websocket.Compression = compress
  4628  	s := RunServer(opts)
  4629  	defer s.Shutdown()
  4630  
  4631  	var subs []*bufio.Reader
  4632  	for i := 0; i < numSubs; i++ {
  4633  		wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port)
  4634  		defer wsc.Close()
  4635  		subProto := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress,
  4636  			[]byte(fmt.Sprintf("SUB %s 1\r\nPING\r\n", testWSBenchSubject)))
  4637  		wsc.Write(subProto)
  4638  		// Waiting for PONG
  4639  		testWSReadFrame(b, br)
  4640  		subs = append(subs, br)
  4641  	}
  4642  
  4643  	wg := sync.WaitGroup{}
  4644  	wg.Add(numSubs)
  4645  
  4646  	// Use regular NATS client to publish messages
  4647  	nc := natsConnect(b, s.ClientURL())
  4648  	defer nc.Close()
  4649  
  4650  	b.StartTimer()
  4651  
  4652  	for i := 0; i < numSubs; i++ {
  4653  		br := subs[i]
  4654  		go func(br *bufio.Reader) {
  4655  			defer wg.Done()
  4656  			for count := 0; count < b.N; {
  4657  				msgs := testWSReadFrame(b, br)
  4658  				count += bytes.Count(msgs, []byte("MSG "))
  4659  			}
  4660  		}(br)
  4661  	}
  4662  	for i := 0; i < b.N; i++ {
  4663  		natsPub(b, nc, testWSBenchSubject, []byte(payload))
  4664  	}
  4665  	wg.Wait()
  4666  	b.StopTimer()
  4667  }
  4668  
  4669  func Benchmark_WS_Subx1_CN_____0b(b *testing.B) {
  4670  	wsBenchSub(b, 1, false, "")
  4671  }
  4672  
  4673  func Benchmark_WS_Subx1_CY_____0b(b *testing.B) {
  4674  	wsBenchSub(b, 1, true, "")
  4675  }
  4676  
  4677  func Benchmark_WS_Subx1_CN___128b(b *testing.B) {
  4678  	s := sizedString(128)
  4679  	wsBenchSub(b, 1, false, s)
  4680  }
  4681  
  4682  func Benchmark_WS_Subx1_CY___128b(b *testing.B) {
  4683  	s := sizedStringForCompression(128)
  4684  	wsBenchSub(b, 1, true, s)
  4685  }
  4686  
  4687  func Benchmark_WS_Subx1_CN__1024b(b *testing.B) {
  4688  	s := sizedString(1024)
  4689  	wsBenchSub(b, 1, false, s)
  4690  }
  4691  
  4692  func Benchmark_WS_Subx1_CY__1024b(b *testing.B) {
  4693  	s := sizedStringForCompression(1024)
  4694  	wsBenchSub(b, 1, true, s)
  4695  }
  4696  
  4697  func Benchmark_WS_Subx1_CN__4096b(b *testing.B) {
  4698  	s := sizedString(4096)
  4699  	wsBenchSub(b, 1, false, s)
  4700  }
  4701  
  4702  func Benchmark_WS_Subx1_CY__4096b(b *testing.B) {
  4703  	s := sizedStringForCompression(4096)
  4704  	wsBenchSub(b, 1, true, s)
  4705  }
  4706  
  4707  func Benchmark_WS_Subx1_CN__8192b(b *testing.B) {
  4708  	s := sizedString(8192)
  4709  	wsBenchSub(b, 1, false, s)
  4710  }
  4711  
  4712  func Benchmark_WS_Subx1_CY__8192b(b *testing.B) {
  4713  	s := sizedStringForCompression(8192)
  4714  	wsBenchSub(b, 1, true, s)
  4715  }
  4716  
  4717  func Benchmark_WS_Subx1_CN_32768b(b *testing.B) {
  4718  	s := sizedString(32768)
  4719  	wsBenchSub(b, 1, false, s)
  4720  }
  4721  
  4722  func Benchmark_WS_Subx1_CY_32768b(b *testing.B) {
  4723  	s := sizedStringForCompression(32768)
  4724  	wsBenchSub(b, 1, true, s)
  4725  }
  4726  
  4727  func Benchmark_WS_Subx5_CN_____0b(b *testing.B) {
  4728  	wsBenchSub(b, 5, false, "")
  4729  }
  4730  
  4731  func Benchmark_WS_Subx5_CY_____0b(b *testing.B) {
  4732  	wsBenchSub(b, 5, true, "")
  4733  }
  4734  
  4735  func Benchmark_WS_Subx5_CN___128b(b *testing.B) {
  4736  	s := sizedString(128)
  4737  	wsBenchSub(b, 5, false, s)
  4738  }
  4739  
  4740  func Benchmark_WS_Subx5_CY___128b(b *testing.B) {
  4741  	s := sizedStringForCompression(128)
  4742  	wsBenchSub(b, 5, true, s)
  4743  }
  4744  
  4745  func Benchmark_WS_Subx5_CN__1024b(b *testing.B) {
  4746  	s := sizedString(1024)
  4747  	wsBenchSub(b, 5, false, s)
  4748  }
  4749  
  4750  func Benchmark_WS_Subx5_CY__1024b(b *testing.B) {
  4751  	s := sizedStringForCompression(1024)
  4752  	wsBenchSub(b, 5, true, s)
  4753  }
  4754  
  4755  func Benchmark_WS_Subx5_CN__4096b(b *testing.B) {
  4756  	s := sizedString(4096)
  4757  	wsBenchSub(b, 5, false, s)
  4758  }
  4759  
  4760  func Benchmark_WS_Subx5_CY__4096b(b *testing.B) {
  4761  	s := sizedStringForCompression(4096)
  4762  	wsBenchSub(b, 5, true, s)
  4763  }
  4764  
  4765  func Benchmark_WS_Subx5_CN__8192b(b *testing.B) {
  4766  	s := sizedString(8192)
  4767  	wsBenchSub(b, 5, false, s)
  4768  }
  4769  
  4770  func Benchmark_WS_Subx5_CY__8192b(b *testing.B) {
  4771  	s := sizedStringForCompression(8192)
  4772  	wsBenchSub(b, 5, true, s)
  4773  }
  4774  
  4775  func Benchmark_WS_Subx5_CN_32768b(b *testing.B) {
  4776  	s := sizedString(32768)
  4777  	wsBenchSub(b, 5, false, s)
  4778  }
  4779  
  4780  func Benchmark_WS_Subx5_CY_32768b(b *testing.B) {
  4781  	s := sizedStringForCompression(32768)
  4782  	wsBenchSub(b, 5, true, s)
  4783  }