get.pme.sh/pnats@v0.0.0-20240304004023-26bb5a137ed0/server/websocket_test.go (about)

     1  // Copyright 2020-2024 The NATS Authors
     2  // Licensed under the Apache License, Version 2.0 (the "License");
     3  // you may not use this file except in compliance with the License.
     4  // You may obtain a copy of the License at
     5  //
     6  // http://www.apache.org/licenses/LICENSE-2.0
     7  //
     8  // Unless required by applicable law or agreed to in writing, software
     9  // distributed under the License is distributed on an "AS IS" BASIS,
    10  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    11  // See the License for the specific language governing permissions and
    12  // limitations under the License.
    13  
    14  package server
    15  
    16  import (
    17  	"bufio"
    18  	"bytes"
    19  	"crypto/tls"
    20  	"encoding/base64"
    21  	"encoding/binary"
    22  	"encoding/json"
    23  	"errors"
    24  	"fmt"
    25  	"io"
    26  	"math/rand"
    27  	"net"
    28  	"net/http"
    29  	"net/url"
    30  	"reflect"
    31  	"strings"
    32  	"sync"
    33  	"sync/atomic"
    34  	"testing"
    35  	"time"
    36  
    37  	"github.com/nats-io/jwt/v2"
    38  	"github.com/nats-io/nats.go"
    39  	"github.com/nats-io/nkeys"
    40  
    41  	"github.com/klauspost/compress/flate"
    42  )
    43  
    44  type testReader struct {
    45  	buf []byte
    46  	pos int
    47  	max int
    48  	err error
    49  }
    50  
    51  func (tr *testReader) Read(p []byte) (int, error) {
    52  	if tr.err != nil {
    53  		return 0, tr.err
    54  	}
    55  	n := len(tr.buf) - tr.pos
    56  	if n == 0 {
    57  		return 0, nil
    58  	}
    59  	if n > len(p) {
    60  		n = len(p)
    61  	}
    62  	if tr.max > 0 && n > tr.max {
    63  		n = tr.max
    64  	}
    65  	copy(p, tr.buf[tr.pos:tr.pos+n])
    66  	tr.pos += n
    67  	return n, nil
    68  }
    69  
    70  func TestWSGet(t *testing.T) {
    71  	rb := []byte("012345")
    72  
    73  	tr := &testReader{buf: []byte("6789")}
    74  
    75  	for _, test := range []struct {
    76  		name   string
    77  		pos    int
    78  		needed int
    79  		newpos int
    80  		trmax  int
    81  		result string
    82  		reterr bool
    83  	}{
    84  		{"fromrb1", 0, 3, 3, 4, "012", false},    // Partial from read buffer
    85  		{"fromrb2", 3, 2, 5, 4, "34", false},     // Partial from read buffer
    86  		{"fromrb3", 5, 1, 6, 4, "5", false},      // Partial from read buffer
    87  		{"fromtr1", 4, 4, 6, 4, "4567", false},   // Partial from read buffer + some of ioReader
    88  		{"fromtr2", 4, 6, 6, 4, "456789", false}, // Partial from read buffer + all of ioReader
    89  		{"fromtr3", 4, 6, 6, 2, "456789", false}, // Partial from read buffer + all of ioReader with several reads
    90  		{"fromtr4", 4, 6, 6, 2, "", true},        // ioReader returns error
    91  	} {
    92  		t.Run(test.name, func(t *testing.T) {
    93  			tr.pos = 0
    94  			tr.max = test.trmax
    95  			if test.reterr {
    96  				tr.err = fmt.Errorf("on purpose")
    97  			}
    98  			res, np, err := wsGet(tr, rb, test.pos, test.needed)
    99  			if test.reterr {
   100  				if err == nil {
   101  					t.Fatalf("Expected error, got none")
   102  				}
   103  				if err.Error() != "on purpose" {
   104  					t.Fatalf("Unexpected error: %v", err)
   105  				}
   106  				if np != 0 || res != nil {
   107  					t.Fatalf("Unexpected returned values: res=%v n=%v", res, np)
   108  				}
   109  				return
   110  			}
   111  			if err != nil {
   112  				t.Fatalf("Error on get: %v", err)
   113  			}
   114  			if np != test.newpos {
   115  				t.Fatalf("Expected pos=%v, got %v", test.newpos, np)
   116  			}
   117  			if string(res) != test.result {
   118  				t.Fatalf("Invalid returned content: %s", res)
   119  			}
   120  		})
   121  	}
   122  }
   123  
   124  func TestWSIsControlFrame(t *testing.T) {
   125  	for _, test := range []struct {
   126  		name      string
   127  		code      wsOpCode
   128  		isControl bool
   129  	}{
   130  		{"binary", wsBinaryMessage, false},
   131  		{"text", wsTextMessage, false},
   132  		{"ping", wsPingMessage, true},
   133  		{"pong", wsPongMessage, true},
   134  		{"close", wsCloseMessage, true},
   135  	} {
   136  		t.Run(test.name, func(t *testing.T) {
   137  			if res := wsIsControlFrame(test.code); res != test.isControl {
   138  				t.Fatalf("Expected %q isControl to be %v, got %v", test.name, test.isControl, res)
   139  			}
   140  		})
   141  	}
   142  }
   143  
   144  func testWSSimpleMask(key, buf []byte) {
   145  	for i := 0; i < len(buf); i++ {
   146  		buf[i] ^= key[i&3]
   147  	}
   148  }
   149  
   150  func TestWSUnmask(t *testing.T) {
   151  	key := []byte{1, 2, 3, 4}
   152  	orgBuf := []byte("this is a clear text")
   153  
   154  	mask := func() []byte {
   155  		t.Helper()
   156  		buf := append([]byte(nil), orgBuf...)
   157  		testWSSimpleMask(key, buf)
   158  		// First ensure that the content is masked.
   159  		if bytes.Equal(buf, orgBuf) {
   160  			t.Fatalf("Masking did not do anything: %q", buf)
   161  		}
   162  		return buf
   163  	}
   164  
   165  	ri := &wsReadInfo{mask: true}
   166  	ri.init()
   167  	copy(ri.mkey[:], key)
   168  
   169  	buf := mask()
   170  	// Unmask in one call
   171  	ri.unmask(buf)
   172  	if !bytes.Equal(buf, orgBuf) {
   173  		t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf)
   174  	}
   175  
   176  	// Unmask in multiple calls
   177  	buf = mask()
   178  	ri.mkpos = 0
   179  	ri.unmask(buf[:3])
   180  	ri.unmask(buf[3:11])
   181  	ri.unmask(buf[11:])
   182  	if !bytes.Equal(buf, orgBuf) {
   183  		t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf)
   184  	}
   185  }
   186  
   187  func TestWSCreateCloseMessage(t *testing.T) {
   188  	for _, test := range []struct {
   189  		name      string
   190  		status    int
   191  		psize     int
   192  		truncated bool
   193  	}{
   194  		{"fits", wsCloseStatusInternalSrvError, 10, false},
   195  		{"truncated", wsCloseStatusProtocolError, wsMaxControlPayloadSize + 10, true},
   196  	} {
   197  		t.Run(test.name, func(t *testing.T) {
   198  			payload := make([]byte, test.psize)
   199  			for i := 0; i < len(payload); i++ {
   200  				payload[i] = byte('A' + (i % 26))
   201  			}
   202  			res := wsCreateCloseMessage(test.status, string(payload))
   203  			if status := binary.BigEndian.Uint16(res[:2]); int(status) != test.status {
   204  				t.Fatalf("Expected status to be %v, got %v", test.status, status)
   205  			}
   206  			psize := len(res) - 2
   207  			if !test.truncated {
   208  				if int(psize) != test.psize {
   209  					t.Fatalf("Expected size to be %v, got %v", test.psize, psize)
   210  				}
   211  				if !bytes.Equal(res[2:], payload) {
   212  					t.Fatalf("Unexpected result: %q", res[2:])
   213  				}
   214  				return
   215  			}
   216  			// Since the payload of a close message contains a 2 byte status, the
   217  			// actual max text size will be wsMaxControlPayloadSize-2
   218  			if int(psize) != wsMaxControlPayloadSize-2 {
   219  				t.Fatalf("Expected size to be capped to %v, got %v", wsMaxControlPayloadSize-2, psize)
   220  			}
   221  			if string(res[len(res)-3:]) != "..." {
   222  				t.Fatalf("Expected res to have `...` at the end, got %q", res[4:])
   223  			}
   224  		})
   225  	}
   226  }
   227  
   228  func TestWSCreateFrameHeader(t *testing.T) {
   229  	for _, test := range []struct {
   230  		name       string
   231  		frameType  wsOpCode
   232  		compressed bool
   233  		len        int
   234  	}{
   235  		{"uncompressed 10", wsBinaryMessage, false, 10},
   236  		{"uncompressed 600", wsTextMessage, false, 600},
   237  		{"uncompressed 100000", wsTextMessage, false, 100000},
   238  		{"compressed 10", wsBinaryMessage, true, 10},
   239  		{"compressed 600", wsBinaryMessage, true, 600},
   240  		{"compressed 100000", wsTextMessage, true, 100000},
   241  	} {
   242  		t.Run(test.name, func(t *testing.T) {
   243  			res, _ := wsCreateFrameHeader(false, test.compressed, test.frameType, test.len)
   244  			// The server is always sending the message has a single frame,
   245  			// so the "final" bit should be set.
   246  			expected := byte(test.frameType) | wsFinalBit
   247  			if test.compressed {
   248  				expected |= wsRsv1Bit
   249  			}
   250  			if b := res[0]; b != expected {
   251  				t.Fatalf("Expected first byte to be %v, got %v", expected, b)
   252  			}
   253  			switch {
   254  			case test.len <= 125:
   255  				if len(res) != 2 {
   256  					t.Fatalf("Frame len should be 2, got %v", len(res))
   257  				}
   258  				if res[1] != byte(test.len) {
   259  					t.Fatalf("Expected len to be in second byte and be %v, got %v", test.len, res[1])
   260  				}
   261  			case test.len < 65536:
   262  				// 1+1+2
   263  				if len(res) != 4 {
   264  					t.Fatalf("Frame len should be 4, got %v", len(res))
   265  				}
   266  				if res[1] != 126 {
   267  					t.Fatalf("Second byte value should be 126, got %v", res[1])
   268  				}
   269  				if rl := binary.BigEndian.Uint16(res[2:]); int(rl) != test.len {
   270  					t.Fatalf("Expected len to be %v, got %v", test.len, rl)
   271  				}
   272  			default:
   273  				// 1+1+8
   274  				if len(res) != 10 {
   275  					t.Fatalf("Frame len should be 10, got %v", len(res))
   276  				}
   277  				if res[1] != 127 {
   278  					t.Fatalf("Second byte value should be 127, got %v", res[1])
   279  				}
   280  				if rl := binary.BigEndian.Uint64(res[2:]); int(rl) != test.len {
   281  					t.Fatalf("Expected len to be %v, got %v", test.len, rl)
   282  				}
   283  			}
   284  		})
   285  	}
   286  }
   287  
   288  func testWSCreateClientMsg(frameType wsOpCode, frameNum int, final, compressed bool, payload []byte) []byte {
   289  	if compressed {
   290  		buf := &bytes.Buffer{}
   291  		compressor, _ := flate.NewWriter(buf, 1)
   292  		compressor.Write(payload)
   293  		compressor.Flush()
   294  		payload = buf.Bytes()
   295  		// The last 4 bytes are dropped
   296  		payload = payload[:len(payload)-4]
   297  	}
   298  	frame := make([]byte, 14+len(payload))
   299  	if frameNum == 1 {
   300  		frame[0] = byte(frameType)
   301  	}
   302  	if final {
   303  		frame[0] |= wsFinalBit
   304  	}
   305  	if compressed {
   306  		frame[0] |= wsRsv1Bit
   307  	}
   308  	pos := 1
   309  	lenPayload := len(payload)
   310  	switch {
   311  	case lenPayload <= 125:
   312  		frame[pos] = byte(lenPayload) | wsMaskBit
   313  		pos++
   314  	case lenPayload < 65536:
   315  		frame[pos] = 126 | wsMaskBit
   316  		binary.BigEndian.PutUint16(frame[2:], uint16(lenPayload))
   317  		pos += 3
   318  	default:
   319  		frame[1] = 127 | wsMaskBit
   320  		binary.BigEndian.PutUint64(frame[2:], uint64(lenPayload))
   321  		pos += 9
   322  	}
   323  	key := []byte{1, 2, 3, 4}
   324  	copy(frame[pos:], key)
   325  	pos += 4
   326  	copy(frame[pos:], payload)
   327  	testWSSimpleMask(key, frame[pos:])
   328  	pos += lenPayload
   329  	return frame[:pos]
   330  }
   331  
   332  func testWSSetupForRead() (*client, *wsReadInfo, *testReader) {
   333  	ri := &wsReadInfo{mask: true}
   334  	ri.init()
   335  	tr := &testReader{}
   336  	opts := DefaultOptions()
   337  	opts.MaxPending = MAX_PENDING_SIZE
   338  	s := &Server{opts: opts}
   339  	c := &client{srv: s, ws: &websocket{}}
   340  	c.initClient()
   341  	return c, ri, tr
   342  }
   343  
   344  func TestWSReadUncompressedFrames(t *testing.T) {
   345  	c, ri, tr := testWSSetupForRead()
   346  	// Create 2 WS messages
   347  	pl1 := []byte("first message")
   348  	wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl1)
   349  	pl2 := []byte("second message")
   350  	wsmsg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl2)
   351  	// Add both in single buffer
   352  	orgrb := append([]byte(nil), wsmsg1...)
   353  	orgrb = append(orgrb, wsmsg2...)
   354  
   355  	rb := append([]byte(nil), orgrb...)
   356  	bufs, err := c.wsRead(ri, tr, rb)
   357  	if err != nil {
   358  		t.Fatalf("Unexpected error: %v", err)
   359  	}
   360  	if n := len(bufs); n != 2 {
   361  		t.Fatalf("Expected 2 buffers, got %v", n)
   362  	}
   363  	if !bytes.Equal(bufs[0], pl1) {
   364  		t.Fatalf("Unexpected content for buffer 1: %s", bufs[0])
   365  	}
   366  	if !bytes.Equal(bufs[1], pl2) {
   367  		t.Fatalf("Unexpected content for buffer 2: %s", bufs[1])
   368  	}
   369  
   370  	// Now reset and try with the read buffer not containing full ws frame
   371  	c, ri, tr = testWSSetupForRead()
   372  	rb = append([]byte(nil), orgrb...)
   373  	// Frame is 1+1+4+'first message'. So say we pass with rb of 11 bytes,
   374  	// then we should get "first"
   375  	bufs, err = c.wsRead(ri, tr, rb[:11])
   376  	if err != nil {
   377  		t.Fatalf("Unexpected error: %v", err)
   378  	}
   379  	if n := len(bufs); n != 1 {
   380  		t.Fatalf("Unexpected buffer returned: %v", n)
   381  	}
   382  	if string(bufs[0]) != "first" {
   383  		t.Fatalf("Unexpected content: %q", bufs[0])
   384  	}
   385  	// Call again with more data..
   386  	bufs, err = c.wsRead(ri, tr, rb[11:32])
   387  	if err != nil {
   388  		t.Fatalf("Unexpected error: %v", err)
   389  	}
   390  	if n := len(bufs); n != 2 {
   391  		t.Fatalf("Unexpected buffer returned: %v", n)
   392  	}
   393  	if string(bufs[0]) != " message" {
   394  		t.Fatalf("Unexpected content: %q", bufs[0])
   395  	}
   396  	if string(bufs[1]) != "second " {
   397  		t.Fatalf("Unexpected content: %q", bufs[1])
   398  	}
   399  	// Call with the rest
   400  	bufs, err = c.wsRead(ri, tr, rb[32:])
   401  	if err != nil {
   402  		t.Fatalf("Unexpected error: %v", err)
   403  	}
   404  	if n := len(bufs); n != 1 {
   405  		t.Fatalf("Unexpected buffer returned: %v", n)
   406  	}
   407  	if string(bufs[0]) != "message" {
   408  		t.Fatalf("Unexpected content: %q", bufs[0])
   409  	}
   410  }
   411  
   412  func TestWSReadCompressedFrames(t *testing.T) {
   413  	c, ri, tr := testWSSetupForRead()
   414  	uncompressed := []byte("this is the uncompress data")
   415  	wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed)
   416  	rb := append([]byte(nil), wsmsg1...)
   417  	// Call with some but not all of the payload
   418  	bufs, err := c.wsRead(ri, tr, rb[:10])
   419  	if err != nil {
   420  		t.Fatalf("Unexpected error: %v", err)
   421  	}
   422  	if n := len(bufs); n != 0 {
   423  		t.Fatalf("Unexpected buffer returned: %v", n)
   424  	}
   425  	// Call with the rest, only then should we get the uncompressed data.
   426  	bufs, err = c.wsRead(ri, tr, rb[10:])
   427  	if err != nil {
   428  		t.Fatalf("Unexpected error: %v", err)
   429  	}
   430  	if n := len(bufs); n != 1 {
   431  		t.Fatalf("Unexpected buffer returned: %v", n)
   432  	}
   433  	if !bytes.Equal(bufs[0], uncompressed) {
   434  		t.Fatalf("Unexpected content: %s", bufs[0])
   435  	}
   436  	// Stress the fact that we use a pool and want to make sure
   437  	// that if we get a decompressor from the pool, it is properly reset
   438  	// with the buffer to decompress.
   439  	// Since we unmask the read buffer, reset it now and fill it
   440  	// with 10 compressed frames.
   441  	rb = nil
   442  	for i := 0; i < 10; i++ {
   443  		rb = append(rb, wsmsg1...)
   444  	}
   445  	bufs, err = c.wsRead(ri, tr, rb)
   446  	if err != nil {
   447  		t.Fatalf("Unexpected error: %v", err)
   448  	}
   449  	if n := len(bufs); n != 10 {
   450  		t.Fatalf("Unexpected buffer returned: %v", n)
   451  	}
   452  
   453  	// Compress a message and send it in several frames.
   454  	buf := &bytes.Buffer{}
   455  	compressor, _ := flate.NewWriter(buf, 1)
   456  	compressor.Write(uncompressed)
   457  	compressor.Flush()
   458  	compressed := buf.Bytes()
   459  	// The last 4 bytes are dropped
   460  	compressed = compressed[:len(compressed)-4]
   461  	ncomp := 10
   462  	frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, compressed[:ncomp])
   463  	frag1[0] |= wsRsv1Bit
   464  	frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, compressed[ncomp:])
   465  	rb = append([]byte(nil), frag1...)
   466  	rb = append(rb, frag2...)
   467  	bufs, err = c.wsRead(ri, tr, rb)
   468  	if err != nil {
   469  		t.Fatalf("Unexpected error: %v", err)
   470  	}
   471  	if n := len(bufs); n != 1 {
   472  		t.Fatalf("Unexpected buffer returned: %v", n)
   473  	}
   474  	if !bytes.Equal(bufs[0], uncompressed) {
   475  		t.Fatalf("Unexpected content: %s", bufs[0])
   476  	}
   477  }
   478  
   479  func TestWSReadCompressedFrameCorrupted(t *testing.T) {
   480  	c, ri, tr := testWSSetupForRead()
   481  	uncompressed := []byte("this is the uncompress data")
   482  	wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed)
   483  	copy(wsmsg1[10:], []byte{1, 2, 3, 4})
   484  	rb := append([]byte(nil), wsmsg1...)
   485  	bufs, err := c.wsRead(ri, tr, rb)
   486  	if err == nil || !strings.Contains(err.Error(), "corrupt") {
   487  		t.Fatalf("Expected error about corrupted data, got %v", err)
   488  	}
   489  	if n := len(bufs); n != 0 {
   490  		t.Fatalf("Expected no buffer, got %v", n)
   491  	}
   492  }
   493  
   494  func TestWSReadVariousFrameSizes(t *testing.T) {
   495  	for _, test := range []struct {
   496  		name string
   497  		size int
   498  	}{
   499  		{"tiny", 100},
   500  		{"medium", 1000},
   501  		{"large", 70000},
   502  	} {
   503  		t.Run(test.name, func(t *testing.T) {
   504  			c, ri, tr := testWSSetupForRead()
   505  			uncompressed := make([]byte, test.size)
   506  			for i := 0; i < len(uncompressed); i++ {
   507  				uncompressed[i] = 'A' + byte(i%26)
   508  			}
   509  			wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, uncompressed)
   510  			rb := append([]byte(nil), wsmsg1...)
   511  			bufs, err := c.wsRead(ri, tr, rb)
   512  			if err != nil {
   513  				t.Fatalf("Unexpected error: %v", err)
   514  			}
   515  			if n := len(bufs); n != 1 {
   516  				t.Fatalf("Unexpected buffer returned: %v", n)
   517  			}
   518  			if !bytes.Equal(bufs[0], uncompressed) {
   519  				t.Fatalf("Unexpected content: %s", bufs[0])
   520  			}
   521  		})
   522  	}
   523  }
   524  
   525  func TestWSReadFragmentedFrames(t *testing.T) {
   526  	c, ri, tr := testWSSetupForRead()
   527  	payloads := []string{"first", "second", "third"}
   528  	var rb []byte
   529  	for i := 0; i < len(payloads); i++ {
   530  		final := i == len(payloads)-1
   531  		frag := testWSCreateClientMsg(wsBinaryMessage, i+1, final, false, []byte(payloads[i]))
   532  		rb = append(rb, frag...)
   533  	}
   534  	bufs, err := c.wsRead(ri, tr, rb)
   535  	if err != nil {
   536  		t.Fatalf("Unexpected error: %v", err)
   537  	}
   538  	if n := len(bufs); n != 3 {
   539  		t.Fatalf("Unexpected buffer returned: %v", n)
   540  	}
   541  	for i, expected := range payloads {
   542  		if string(bufs[i]) != expected {
   543  			t.Fatalf("Unexpected content for buf=%v: %s", i, bufs[i])
   544  		}
   545  	}
   546  }
   547  
   548  func TestWSReadPartialFrameHeaderAtEndOfReadBuffer(t *testing.T) {
   549  	c, ri, tr := testWSSetupForRead()
   550  	msg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg1"))
   551  	msg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg2"))
   552  	rb := append([]byte(nil), msg1...)
   553  	rb = append(rb, msg2...)
   554  	// We will pass the first frame + the first byte of the next frame.
   555  	rbl := rb[:len(msg1)+1]
   556  	// Make the io reader return the rest of the frame
   557  	tr.buf = rb[len(msg1)+1:]
   558  	bufs, err := c.wsRead(ri, tr, rbl)
   559  	if err != nil {
   560  		t.Fatalf("Unexpected error: %v", err)
   561  	}
   562  	if n := len(bufs); n != 1 {
   563  		t.Fatalf("Unexpected buffer returned: %v", n)
   564  	}
   565  	// We should not have asked to the io reader more than what is needed for reading
   566  	// the frame header. Since we had already the first byte in the read buffer,
   567  	// tr.pos should be 1(size)+4(key)=5
   568  	if tr.pos != 5 {
   569  		t.Fatalf("Expected reader pos to be 5, got %v", tr.pos)
   570  	}
   571  }
   572  
   573  func TestWSReadPingFrame(t *testing.T) {
   574  	for _, test := range []struct {
   575  		name    string
   576  		payload []byte
   577  	}{
   578  		{"without payload", nil},
   579  		{"with payload", []byte("optional payload")},
   580  	} {
   581  		t.Run(test.name, func(t *testing.T) {
   582  			c, ri, tr := testWSSetupForRead()
   583  			ping := testWSCreateClientMsg(wsPingMessage, 1, true, false, test.payload)
   584  			rb := append([]byte(nil), ping...)
   585  			bufs, err := c.wsRead(ri, tr, rb)
   586  			if err != nil {
   587  				t.Fatalf("Unexpected error: %v", err)
   588  			}
   589  			if n := len(bufs); n != 0 {
   590  				t.Fatalf("Unexpected buffer returned: %v", n)
   591  			}
   592  			// A PONG should have been queued with the payload of the ping
   593  			c.mu.Lock()
   594  			nb, _ := c.collapsePtoNB()
   595  			c.mu.Unlock()
   596  			if n := len(nb); n == 0 {
   597  				t.Fatalf("Expected buffers, got %v", n)
   598  			}
   599  			if expected := 2 + len(test.payload); expected != len(nb[0]) {
   600  				t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
   601  			}
   602  			b := nb[0][0]
   603  			if b&wsFinalBit == 0 {
   604  				t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
   605  			}
   606  			if b&byte(wsPongMessage) == 0 {
   607  				t.Fatalf("Should have been a PONG, it wasn't: %v", b)
   608  			}
   609  			if len(test.payload) > 0 {
   610  				if !bytes.Equal(nb[0][2:], test.payload) {
   611  					t.Fatalf("Unexpected content: %s", nb[0][2:])
   612  				}
   613  			}
   614  		})
   615  	}
   616  }
   617  
   618  func TestWSReadPongFrame(t *testing.T) {
   619  	for _, test := range []struct {
   620  		name    string
   621  		payload []byte
   622  	}{
   623  		{"without payload", nil},
   624  		{"with payload", []byte("optional payload")},
   625  	} {
   626  		t.Run(test.name, func(t *testing.T) {
   627  			c, ri, tr := testWSSetupForRead()
   628  			pong := testWSCreateClientMsg(wsPongMessage, 1, true, false, test.payload)
   629  			rb := append([]byte(nil), pong...)
   630  			bufs, err := c.wsRead(ri, tr, rb)
   631  			if err != nil {
   632  				t.Fatalf("Unexpected error: %v", err)
   633  			}
   634  			if n := len(bufs); n != 0 {
   635  				t.Fatalf("Unexpected buffer returned: %v", n)
   636  			}
   637  			// Nothing should be sent...
   638  			c.mu.Lock()
   639  			nb, _ := c.collapsePtoNB()
   640  			c.mu.Unlock()
   641  			if n := len(nb); n != 0 {
   642  				t.Fatalf("Expected no buffer, got %v", n)
   643  			}
   644  		})
   645  	}
   646  }
   647  
   648  func TestWSReadCloseFrame(t *testing.T) {
   649  	for _, test := range []struct {
   650  		name    string
   651  		payload []byte
   652  	}{
   653  		{"without payload", nil},
   654  		{"with payload", []byte("optional payload")},
   655  	} {
   656  		t.Run(test.name, func(t *testing.T) {
   657  			c, ri, tr := testWSSetupForRead()
   658  			// a close message has a status in 2 bytes + optional payload
   659  			payload := make([]byte, 2+len(test.payload))
   660  			binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure)
   661  			if len(test.payload) > 0 {
   662  				copy(payload[2:], test.payload)
   663  			}
   664  			close := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload)
   665  			// Have a normal frame prior to close to make sure that wsRead returns
   666  			// the normal frame along with io.EOF to indicate that wsCloseMessage was received.
   667  			msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg"))
   668  			rb := append([]byte(nil), msg...)
   669  			rb = append(rb, close...)
   670  			bufs, err := c.wsRead(ri, tr, rb)
   671  			// It is expected that wsRead returns io.EOF on processing a close.
   672  			if err != io.EOF {
   673  				t.Fatalf("Unexpected error: %v", err)
   674  			}
   675  			if n := len(bufs); n != 1 {
   676  				t.Fatalf("Unexpected buffer returned: %v", n)
   677  			}
   678  			if string(bufs[0]) != "msg" {
   679  				t.Fatalf("Unexpected content: %s", bufs[0])
   680  			}
   681  			// A CLOSE should have been queued with the payload of the original close message.
   682  			c.mu.Lock()
   683  			nb, _ := c.collapsePtoNB()
   684  			c.mu.Unlock()
   685  			if n := len(nb); n == 0 {
   686  				t.Fatalf("Expected buffers, got %v", n)
   687  			}
   688  			if expected := 2 + 2 + len(test.payload); expected != len(nb[0]) {
   689  				t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
   690  			}
   691  			b := nb[0][0]
   692  			if b&wsFinalBit == 0 {
   693  				t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
   694  			}
   695  			if b&byte(wsCloseMessage) == 0 {
   696  				t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
   697  			}
   698  			if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure {
   699  				t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status)
   700  			}
   701  			if len(test.payload) > 0 {
   702  				if !bytes.Equal(nb[0][4:], test.payload) {
   703  					t.Fatalf("Unexpected content: %s", nb[0][4:])
   704  				}
   705  			}
   706  		})
   707  	}
   708  }
   709  
   710  func TestWSReadControlFrameBetweebFragmentedFrames(t *testing.T) {
   711  	c, ri, tr := testWSSetupForRead()
   712  	frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("first"))
   713  	frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, []byte("second"))
   714  	ctrl := testWSCreateClientMsg(wsPongMessage, 1, true, false, nil)
   715  	rb := append([]byte(nil), frag1...)
   716  	rb = append(rb, ctrl...)
   717  	rb = append(rb, frag2...)
   718  	bufs, err := c.wsRead(ri, tr, rb)
   719  	if err != nil {
   720  		t.Fatalf("Unexpected error: %v", err)
   721  	}
   722  	if n := len(bufs); n != 2 {
   723  		t.Fatalf("Unexpected buffer returned: %v", n)
   724  	}
   725  	if string(bufs[0]) != "first" {
   726  		t.Fatalf("Unexpected content: %s", bufs[0])
   727  	}
   728  	if string(bufs[1]) != "second" {
   729  		t.Fatalf("Unexpected content: %s", bufs[1])
   730  	}
   731  }
   732  
   733  func TestWSCloseFrameWithPartialOrInvalid(t *testing.T) {
   734  	c, ri, tr := testWSSetupForRead()
   735  	// a close message has a status in 2 bytes + optional payload
   736  	payloadTxt := []byte("hello")
   737  	payload := make([]byte, 2+len(payloadTxt))
   738  	binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure)
   739  	copy(payload[2:], payloadTxt)
   740  	closeMsg := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload)
   741  
   742  	// We will pass to wsRead a buffer of small capacity that contains
   743  	// only 1 byte.
   744  	closeFirtByte := []byte{closeMsg[0]}
   745  	// Make the io reader return the rest of the frame
   746  	tr.buf = closeMsg[1:]
   747  	bufs, err := c.wsRead(ri, tr, closeFirtByte[:])
   748  	// It is expected that wsRead returns io.EOF on processing a close.
   749  	if err != io.EOF {
   750  		t.Fatalf("Unexpected error: %v", err)
   751  	}
   752  	if n := len(bufs); n != 0 {
   753  		t.Fatalf("Unexpected buffer returned: %v", n)
   754  	}
   755  	// A CLOSE should have been queued with the payload of the original close message.
   756  	c.mu.Lock()
   757  	nb, _ := c.collapsePtoNB()
   758  	c.mu.Unlock()
   759  	if n := len(nb); n == 0 {
   760  		t.Fatalf("Expected buffers, got %v", n)
   761  	}
   762  	if expected := 2 + 2 + len(payloadTxt); expected != len(nb[0]) {
   763  		t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
   764  	}
   765  	b := nb[0][0]
   766  	if b&wsFinalBit == 0 {
   767  		t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
   768  	}
   769  	if b&byte(wsCloseMessage) == 0 {
   770  		t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
   771  	}
   772  	if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure {
   773  		t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status)
   774  	}
   775  	if !bytes.Equal(nb[0][4:], payloadTxt) {
   776  		t.Fatalf("Unexpected content: %s", nb[0][4:])
   777  	}
   778  
   779  	// Now test close with invalid status size (1 instead of 2 bytes)
   780  	c, ri, tr = testWSSetupForRead()
   781  	payload[0] = 100
   782  	binary.BigEndian.PutUint16(payload, wsCloseStatusNormalClosure)
   783  	closeMsg = testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload[:1])
   784  
   785  	// We will pass to wsRead a buffer of small capacity that contains
   786  	// only 1 byte.
   787  	closeFirtByte = []byte{closeMsg[0]}
   788  	// Make the io reader return the rest of the frame
   789  	tr.buf = closeMsg[1:]
   790  	bufs, err = c.wsRead(ri, tr, closeFirtByte[:])
   791  	// It is expected that wsRead returns io.EOF on processing a close.
   792  	if err != io.EOF {
   793  		t.Fatalf("Unexpected error: %v", err)
   794  	}
   795  	if n := len(bufs); n != 0 {
   796  		t.Fatalf("Unexpected buffer returned: %v", n)
   797  	}
   798  	// A CLOSE should have been queued with the payload of the original close message.
   799  	c.mu.Lock()
   800  	nb, _ = c.collapsePtoNB()
   801  	c.mu.Unlock()
   802  	if n := len(nb); n == 0 {
   803  		t.Fatalf("Expected buffers, got %v", n)
   804  	}
   805  	if expected := 2 + 2; expected != len(nb[0]) {
   806  		t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0]))
   807  	}
   808  	b = nb[0][0]
   809  	if b&wsFinalBit == 0 {
   810  		t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
   811  	}
   812  	if b&byte(wsCloseMessage) == 0 {
   813  		t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
   814  	}
   815  	// Since satus was not valid, we should get wsCloseStatusNoStatusReceived
   816  	if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNoStatusReceived {
   817  		t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNoStatusReceived, status)
   818  	}
   819  	if len(nb[0][:]) != 4 {
   820  		t.Fatalf("Unexpected content: %s", nb[0][2:])
   821  	}
   822  }
   823  
   824  func TestWSReadGetErrors(t *testing.T) {
   825  	tr := &testReader{err: fmt.Errorf("on purpose")}
   826  	for _, test := range []struct {
   827  		lenPayload int
   828  		rbextra    int
   829  	}{
   830  		{10, 1},
   831  		{10, 3},
   832  		{200, 1},
   833  		{200, 2},
   834  		{200, 5},
   835  		{70000, 1},
   836  		{70000, 5},
   837  		{70000, 13},
   838  	} {
   839  		t.Run("", func(t *testing.T) {
   840  			c, ri, _ := testWSSetupForRead()
   841  			msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg"))
   842  			frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, make([]byte, test.lenPayload))
   843  			rb := append([]byte(nil), msg...)
   844  			rb = append(rb, frame...)
   845  			bufs, err := c.wsRead(ri, tr, rb[:len(msg)+test.rbextra])
   846  			if err == nil || err.Error() != "on purpose" {
   847  				t.Fatalf("Expected 'on purpose' error, got %v", err)
   848  			}
   849  			if n := len(bufs); n != 1 {
   850  				t.Fatalf("Unexpected buffer returned: %v", n)
   851  			}
   852  			if string(bufs[0]) != "msg" {
   853  				t.Fatalf("Unexpected content: %s", bufs[0])
   854  			}
   855  		})
   856  	}
   857  }
   858  
   859  func TestWSHandleControlFrameErrors(t *testing.T) {
   860  	c, ri, tr := testWSSetupForRead()
   861  	tr.err = fmt.Errorf("on purpose")
   862  
   863  	// a close message has a status in 2 bytes + optional payload
   864  	text := []byte("this is a close message")
   865  	payload := make([]byte, 2+len(text))
   866  	binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure)
   867  	copy(payload[2:], text)
   868  	ctrl := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload)
   869  
   870  	bufs, err := c.wsRead(ri, tr, ctrl[:len(ctrl)-4])
   871  	if err == nil || err.Error() != "on purpose" {
   872  		t.Fatalf("Expected 'on purpose' error, got %v", err)
   873  	}
   874  	if n := len(bufs); n != 0 {
   875  		t.Fatalf("Unexpected buffer returned: %v", n)
   876  	}
   877  
   878  	// Alter the content of close message. It is supposed to be valid utf-8.
   879  	c, ri, tr = testWSSetupForRead()
   880  	cp := append([]byte(nil), payload...)
   881  	cp[10] = 0xF1
   882  	ctrl = testWSCreateClientMsg(wsCloseMessage, 1, true, false, cp)
   883  	bufs, err = c.wsRead(ri, tr, ctrl)
   884  	// We should still receive an EOF but the message enqueued to the client
   885  	// should contain wsCloseStatusInvalidPayloadData and the error about invalid utf8
   886  	if err != io.EOF {
   887  		t.Fatalf("Unexpected error: %v", err)
   888  	}
   889  	if n := len(bufs); n != 0 {
   890  		t.Fatalf("Unexpected buffer returned: %v", n)
   891  	}
   892  	c.mu.Lock()
   893  	nb, _ := c.collapsePtoNB()
   894  	c.mu.Unlock()
   895  	if n := len(nb); n == 0 {
   896  		t.Fatalf("Expected buffers, got %v", n)
   897  	}
   898  	b := nb[0][0]
   899  	if b&wsFinalBit == 0 {
   900  		t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
   901  	}
   902  	if b&byte(wsCloseMessage) == 0 {
   903  		t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
   904  	}
   905  	if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusInvalidPayloadData {
   906  		t.Fatalf("Expected status to be %v, got %v", wsCloseStatusInvalidPayloadData, status)
   907  	}
   908  	if !bytes.Contains(nb[0][4:], []byte("utf8")) {
   909  		t.Fatalf("Unexpected content: %s", nb[0][4:])
   910  	}
   911  }
   912  
   913  func TestWSReadErrors(t *testing.T) {
   914  	for _, test := range []struct {
   915  		cframe func() []byte
   916  		err    string
   917  		nbufs  int
   918  	}{
   919  		{
   920  			func() []byte {
   921  				msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello"))
   922  				msg[1] &= ^byte(wsMaskBit)
   923  				return msg
   924  			},
   925  			"mask bit missing", 1,
   926  		},
   927  		{
   928  			func() []byte {
   929  				return testWSCreateClientMsg(wsPingMessage, 1, true, false, make([]byte, 200))
   930  			},
   931  			"control frame length bigger than maximum allowed", 1,
   932  		},
   933  		{
   934  			func() []byte {
   935  				return testWSCreateClientMsg(wsPingMessage, 1, false, false, []byte("hello"))
   936  			},
   937  			"control frame does not have final bit set", 1,
   938  		},
   939  		{
   940  			func() []byte {
   941  				frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("frag1"))
   942  				newMsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("new message"))
   943  				all := append([]byte(nil), frag1...)
   944  				all = append(all, newMsg...)
   945  				return all
   946  			},
   947  			"new message started before final frame for previous message was received", 2,
   948  		},
   949  		{
   950  			func() []byte {
   951  				frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("frame"))
   952  				frag := testWSCreateClientMsg(wsBinaryMessage, 2, false, false, []byte("continuation"))
   953  				all := append([]byte(nil), frame...)
   954  				all = append(all, frag...)
   955  				return all
   956  			},
   957  			"invalid continuation frame", 2,
   958  		},
   959  		{
   960  			func() []byte {
   961  				return testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frame"))
   962  			},
   963  			"invalid continuation frame", 1,
   964  		},
   965  		{
   966  			func() []byte {
   967  				return testWSCreateClientMsg(99, 1, false, false, []byte("hello"))
   968  			},
   969  			"unknown opcode", 1,
   970  		},
   971  	} {
   972  		t.Run(test.err, func(t *testing.T) {
   973  			c, ri, tr := testWSSetupForRead()
   974  			// Add a valid message first
   975  			msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello"))
   976  			// Then add the bad frame
   977  			bad := test.cframe()
   978  			// Add them both to a read buffer
   979  			rb := append([]byte(nil), msg...)
   980  			rb = append(rb, bad...)
   981  			bufs, err := c.wsRead(ri, tr, rb)
   982  			if err == nil || !strings.Contains(err.Error(), test.err) {
   983  				t.Fatalf("Expected error to contain %q, got %q", test.err, err.Error())
   984  			}
   985  			if n := len(bufs); n != test.nbufs {
   986  				t.Fatalf("Unexpected number of buffers: %v", n)
   987  			}
   988  			if string(bufs[0]) != "hello" {
   989  				t.Fatalf("Unexpected content: %s", bufs[0])
   990  			}
   991  		})
   992  	}
   993  }
   994  
   995  func TestWSEnqueueCloseMsg(t *testing.T) {
   996  	for _, test := range []struct {
   997  		reason ClosedState
   998  		status int
   999  	}{
  1000  		{ClientClosed, wsCloseStatusNormalClosure},
  1001  		{AuthenticationTimeout, wsCloseStatusPolicyViolation},
  1002  		{AuthenticationViolation, wsCloseStatusPolicyViolation},
  1003  		{SlowConsumerPendingBytes, wsCloseStatusPolicyViolation},
  1004  		{SlowConsumerWriteDeadline, wsCloseStatusPolicyViolation},
  1005  		{MaxAccountConnectionsExceeded, wsCloseStatusPolicyViolation},
  1006  		{MaxConnectionsExceeded, wsCloseStatusPolicyViolation},
  1007  		{MaxControlLineExceeded, wsCloseStatusPolicyViolation},
  1008  		{MaxSubscriptionsExceeded, wsCloseStatusPolicyViolation},
  1009  		{MissingAccount, wsCloseStatusPolicyViolation},
  1010  		{AuthenticationExpired, wsCloseStatusPolicyViolation},
  1011  		{Revocation, wsCloseStatusPolicyViolation},
  1012  		{TLSHandshakeError, wsCloseStatusTLSHandshake},
  1013  		{ParseError, wsCloseStatusProtocolError},
  1014  		{ProtocolViolation, wsCloseStatusProtocolError},
  1015  		{BadClientProtocolVersion, wsCloseStatusProtocolError},
  1016  		{MaxPayloadExceeded, wsCloseStatusMessageTooBig},
  1017  		{ServerShutdown, wsCloseStatusGoingAway},
  1018  		{WriteError, wsCloseStatusAbnormalClosure},
  1019  		{ReadError, wsCloseStatusAbnormalClosure},
  1020  		{StaleConnection, wsCloseStatusAbnormalClosure},
  1021  		{ClosedState(254), wsCloseStatusInternalSrvError},
  1022  	} {
  1023  		t.Run(test.reason.String(), func(t *testing.T) {
  1024  			c, _, _ := testWSSetupForRead()
  1025  			c.wsEnqueueCloseMessage(test.reason)
  1026  			c.mu.Lock()
  1027  			nb, _ := c.collapsePtoNB()
  1028  			c.mu.Unlock()
  1029  			if n := len(nb); n != 1 {
  1030  				t.Fatalf("Expected 1 buffer, got %v", n)
  1031  			}
  1032  			b := nb[0][0]
  1033  			if b&wsFinalBit == 0 {
  1034  				t.Fatalf("Control frame should have been the final flag, it was not set: %v", b)
  1035  			}
  1036  			if b&byte(wsCloseMessage) == 0 {
  1037  				t.Fatalf("Should have been a CLOSE, it wasn't: %v", b)
  1038  			}
  1039  			if status := binary.BigEndian.Uint16(nb[0][2:4]); int(status) != test.status {
  1040  				t.Fatalf("Expected status to be %v, got %v", test.status, status)
  1041  			}
  1042  			if string(nb[0][4:]) != test.reason.String() {
  1043  				t.Fatalf("Unexpected content: %s", nb[0][4:])
  1044  			}
  1045  		})
  1046  	}
  1047  }
  1048  
  1049  type testResponseWriter struct {
  1050  	http.ResponseWriter
  1051  	buf     bytes.Buffer
  1052  	headers http.Header
  1053  	err     error
  1054  	brw     *bufio.ReadWriter
  1055  	conn    *testWSFakeNetConn
  1056  }
  1057  
  1058  func (trw *testResponseWriter) Write(p []byte) (int, error) {
  1059  	return trw.buf.Write(p)
  1060  }
  1061  
  1062  func (trw *testResponseWriter) WriteHeader(status int) {
  1063  	trw.buf.WriteString(fmt.Sprintf("%v", status))
  1064  }
  1065  
  1066  func (trw *testResponseWriter) Header() http.Header {
  1067  	if trw.headers == nil {
  1068  		trw.headers = make(http.Header)
  1069  	}
  1070  	return trw.headers
  1071  }
  1072  
  1073  type testWSFakeNetConn struct {
  1074  	net.Conn
  1075  	wbuf            bytes.Buffer
  1076  	err             error
  1077  	wsOpened        bool
  1078  	isClosed        bool
  1079  	deadlineCleared bool
  1080  }
  1081  
  1082  func (c *testWSFakeNetConn) Write(p []byte) (int, error) {
  1083  	if c.err != nil {
  1084  		return 0, c.err
  1085  	}
  1086  	return c.wbuf.Write(p)
  1087  }
  1088  
  1089  func (c *testWSFakeNetConn) SetDeadline(t time.Time) error {
  1090  	if t.IsZero() {
  1091  		c.deadlineCleared = true
  1092  	}
  1093  	return nil
  1094  }
  1095  
  1096  func (c *testWSFakeNetConn) Close() error {
  1097  	c.isClosed = true
  1098  	return nil
  1099  }
  1100  
  1101  func (trw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
  1102  	if trw.conn == nil {
  1103  		trw.conn = &testWSFakeNetConn{}
  1104  	}
  1105  	trw.conn.wsOpened = true
  1106  	if trw.brw == nil {
  1107  		trw.brw = bufio.NewReadWriter(bufio.NewReader(trw.conn), bufio.NewWriter(trw.conn))
  1108  	}
  1109  	return trw.conn, trw.brw, trw.err
  1110  }
  1111  
  1112  func testWSOptions() *Options {
  1113  	opts := DefaultOptions()
  1114  	opts.DisableShortFirstPing = true
  1115  	opts.Websocket.Host = "127.0.0.1"
  1116  	opts.Websocket.Port = -1
  1117  	opts.NoSystemAccount = true
  1118  	var err error
  1119  	tc := &TLSConfigOpts{
  1120  		CertFile: "./configs/certs/server.pem",
  1121  		KeyFile:  "./configs/certs/key.pem",
  1122  	}
  1123  	opts.Websocket.TLSConfig, err = GenTLSConfig(tc)
  1124  	if err != nil {
  1125  		panic(err)
  1126  	}
  1127  	return opts
  1128  }
  1129  
  1130  func testWSCreateValidReq() *http.Request {
  1131  	req := &http.Request{
  1132  		Method: "GET",
  1133  		Host:   "localhost",
  1134  		Proto:  "HTTP/1.1",
  1135  	}
  1136  	req.Header = make(http.Header)
  1137  	req.Header.Set("Upgrade", "websocket")
  1138  	req.Header.Set("Connection", "Upgrade")
  1139  	req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==")
  1140  	req.Header.Set("Sec-Websocket-Version", "13")
  1141  	return req
  1142  }
  1143  
  1144  func TestWSCheckOrigin(t *testing.T) {
  1145  	notSameOrigin := false
  1146  	sameOrigin := true
  1147  	allowedListEmpty := []string{}
  1148  	someList := []string{"http://host1.com", "http://host2.com:1234"}
  1149  
  1150  	for _, test := range []struct {
  1151  		name       string
  1152  		sameOrigin bool
  1153  		origins    []string
  1154  		reqHost    string
  1155  		reqTLS     bool
  1156  		origin     string
  1157  		err        string
  1158  	}{
  1159  		{"any", notSameOrigin, allowedListEmpty, "", false, "http://any.host.com", ""},
  1160  		{"same origin ok", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:80", ""},
  1161  		{"same origin bad host", sameOrigin, allowedListEmpty, "host.com", false, "http://other.host.com", "not same origin"},
  1162  		{"same origin bad port", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:81", "not same origin"},
  1163  		{"same origin bad scheme", sameOrigin, allowedListEmpty, "host.com", true, "http://host.com", "not same origin"},
  1164  		{"same origin bad uri", sameOrigin, allowedListEmpty, "host.com", false, "@@@://invalid:url:1234", "invalid URI"},
  1165  		{"same origin bad url", sameOrigin, allowedListEmpty, "host.com", false, "http://invalid:url:1234", "too many colons"},
  1166  		{"same origin bad req host", sameOrigin, allowedListEmpty, "invalid:url:1234", false, "http://host.com", "too many colons"},
  1167  		{"no origin same origin ignored", sameOrigin, allowedListEmpty, "", false, "", ""},
  1168  		{"no origin list ignored", sameOrigin, someList, "", false, "", ""},
  1169  		{"no origin same origin and list ignored", sameOrigin, someList, "", false, "", ""},
  1170  		{"allowed from list", notSameOrigin, someList, "", false, "http://host2.com:1234", ""},
  1171  		{"allowed with different path", notSameOrigin, someList, "", false, "http://host1.com/some/path", ""},
  1172  		{"list bad port", notSameOrigin, someList, "", false, "http://host1.com:1234", "not in the allowed list"},
  1173  		{"list bad scheme", notSameOrigin, someList, "", false, "https://host2.com:1234", "not in the allowed list"},
  1174  	} {
  1175  		t.Run(test.name, func(t *testing.T) {
  1176  			opts := DefaultOptions()
  1177  			opts.Websocket.SameOrigin = test.sameOrigin
  1178  			opts.Websocket.AllowedOrigins = test.origins
  1179  			s := &Server{opts: opts}
  1180  			s.wsSetOriginOptions(&opts.Websocket)
  1181  
  1182  			req := testWSCreateValidReq()
  1183  			req.Host = test.reqHost
  1184  			if test.reqTLS {
  1185  				req.TLS = &tls.ConnectionState{}
  1186  			}
  1187  			if test.origin != "" {
  1188  				req.Header.Set("Origin", test.origin)
  1189  			}
  1190  			err := s.websocket.checkOrigin(req)
  1191  			if test.err == "" && err != nil {
  1192  				t.Fatalf("Unexpected error: %v", err)
  1193  			} else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) {
  1194  				t.Fatalf("Expected error %q, got %v", test.err, err)
  1195  			}
  1196  		})
  1197  	}
  1198  }
  1199  
  1200  func TestWSUpgradeValidationErrors(t *testing.T) {
  1201  	for _, test := range []struct {
  1202  		name   string
  1203  		setup  func() (*Options, *testResponseWriter, *http.Request)
  1204  		err    string
  1205  		status int
  1206  	}{
  1207  		{
  1208  			"bad method",
  1209  			func() (*Options, *testResponseWriter, *http.Request) {
  1210  				opts := testWSOptions()
  1211  				req := testWSCreateValidReq()
  1212  				req.Method = "POST"
  1213  				return opts, nil, req
  1214  			},
  1215  			"must be GET",
  1216  			http.StatusMethodNotAllowed,
  1217  		},
  1218  		{
  1219  			"no host",
  1220  			func() (*Options, *testResponseWriter, *http.Request) {
  1221  				opts := testWSOptions()
  1222  				req := testWSCreateValidReq()
  1223  				req.Host = ""
  1224  				return opts, nil, req
  1225  			},
  1226  			"'Host' missing in request",
  1227  			http.StatusBadRequest,
  1228  		},
  1229  		{
  1230  			"invalid upgrade header",
  1231  			func() (*Options, *testResponseWriter, *http.Request) {
  1232  				opts := testWSOptions()
  1233  				req := testWSCreateValidReq()
  1234  				req.Header.Del("Upgrade")
  1235  				return opts, nil, req
  1236  			},
  1237  			"invalid value for header 'Upgrade'",
  1238  			http.StatusBadRequest,
  1239  		},
  1240  		{
  1241  			"invalid connection header",
  1242  			func() (*Options, *testResponseWriter, *http.Request) {
  1243  				opts := testWSOptions()
  1244  				req := testWSCreateValidReq()
  1245  				req.Header.Del("Connection")
  1246  				return opts, nil, req
  1247  			},
  1248  			"invalid value for header 'Connection'",
  1249  			http.StatusBadRequest,
  1250  		},
  1251  		{
  1252  			"no key",
  1253  			func() (*Options, *testResponseWriter, *http.Request) {
  1254  				opts := testWSOptions()
  1255  				req := testWSCreateValidReq()
  1256  				req.Header.Del("Sec-Websocket-Key")
  1257  				return opts, nil, req
  1258  			},
  1259  			"key missing",
  1260  			http.StatusBadRequest,
  1261  		},
  1262  		{
  1263  			"empty key",
  1264  			func() (*Options, *testResponseWriter, *http.Request) {
  1265  				opts := testWSOptions()
  1266  				req := testWSCreateValidReq()
  1267  				req.Header.Set("Sec-Websocket-Key", "")
  1268  				return opts, nil, req
  1269  			},
  1270  			"key missing",
  1271  			http.StatusBadRequest,
  1272  		},
  1273  		{
  1274  			"missing version",
  1275  			func() (*Options, *testResponseWriter, *http.Request) {
  1276  				opts := testWSOptions()
  1277  				req := testWSCreateValidReq()
  1278  				req.Header.Del("Sec-Websocket-Version")
  1279  				return opts, nil, req
  1280  			},
  1281  			"invalid version",
  1282  			http.StatusBadRequest,
  1283  		},
  1284  		{
  1285  			"wrong version",
  1286  			func() (*Options, *testResponseWriter, *http.Request) {
  1287  				opts := testWSOptions()
  1288  				req := testWSCreateValidReq()
  1289  				req.Header.Set("Sec-Websocket-Version", "99")
  1290  				return opts, nil, req
  1291  			},
  1292  			"invalid version",
  1293  			http.StatusBadRequest,
  1294  		},
  1295  		{
  1296  			"origin",
  1297  			func() (*Options, *testResponseWriter, *http.Request) {
  1298  				opts := testWSOptions()
  1299  				opts.Websocket.SameOrigin = true
  1300  				req := testWSCreateValidReq()
  1301  				req.Header.Set("Origin", "http://bad.host.com")
  1302  				return opts, nil, req
  1303  			},
  1304  			"origin not allowed",
  1305  			http.StatusForbidden,
  1306  		},
  1307  		{
  1308  			"hijack error",
  1309  			func() (*Options, *testResponseWriter, *http.Request) {
  1310  				opts := testWSOptions()
  1311  				rw := &testResponseWriter{err: fmt.Errorf("on purpose")}
  1312  				req := testWSCreateValidReq()
  1313  				return opts, rw, req
  1314  			},
  1315  			"on purpose",
  1316  			http.StatusInternalServerError,
  1317  		},
  1318  		{
  1319  			"hijack buffered data",
  1320  			func() (*Options, *testResponseWriter, *http.Request) {
  1321  				opts := testWSOptions()
  1322  				buf := &bytes.Buffer{}
  1323  				buf.WriteString("some data")
  1324  				rw := &testResponseWriter{
  1325  					conn: &testWSFakeNetConn{},
  1326  					brw:  bufio.NewReadWriter(bufio.NewReader(buf), bufio.NewWriter(nil)),
  1327  				}
  1328  				tmp := [1]byte{}
  1329  				io.ReadAtLeast(rw.brw, tmp[:1], 1)
  1330  				req := testWSCreateValidReq()
  1331  				return opts, rw, req
  1332  			},
  1333  			"client sent data before handshake is complete",
  1334  			http.StatusBadRequest,
  1335  		},
  1336  	} {
  1337  		t.Run(test.name, func(t *testing.T) {
  1338  			opts, rw, req := test.setup()
  1339  			if rw == nil {
  1340  				rw = &testResponseWriter{}
  1341  			}
  1342  			s := &Server{opts: opts}
  1343  			s.wsSetOriginOptions(&opts.Websocket)
  1344  			res, err := s.wsUpgrade(rw, req)
  1345  			if err == nil || !strings.Contains(err.Error(), test.err) {
  1346  				t.Fatalf("Should get error %q, got %v", test.err, err)
  1347  			}
  1348  			if res != nil {
  1349  				t.Fatalf("Should not have returned a result, got %v", res)
  1350  			}
  1351  			expected := fmt.Sprintf("%v%s\n", test.status, http.StatusText(test.status))
  1352  			if got := rw.buf.String(); got != expected {
  1353  				t.Fatalf("Expected %q got %q", expected, got)
  1354  			}
  1355  			// Check that if the connection was opened, it is now closed.
  1356  			if rw.conn != nil && rw.conn.wsOpened && !rw.conn.isClosed {
  1357  				t.Fatal("Connection was opened, but has not been closed")
  1358  			}
  1359  		})
  1360  	}
  1361  }
  1362  
  1363  func TestWSUpgradeResponseWriteError(t *testing.T) {
  1364  	opts := testWSOptions()
  1365  	s := &Server{opts: opts}
  1366  	expectedErr := errors.New("on purpose")
  1367  	rw := &testResponseWriter{
  1368  		conn: &testWSFakeNetConn{err: expectedErr},
  1369  	}
  1370  	req := testWSCreateValidReq()
  1371  	res, err := s.wsUpgrade(rw, req)
  1372  	if err != expectedErr {
  1373  		t.Fatalf("Should get error %q, got %v", expectedErr.Error(), err)
  1374  	}
  1375  	if res != nil {
  1376  		t.Fatalf("Should not have returned a result, got %v", res)
  1377  	}
  1378  	if !rw.conn.isClosed {
  1379  		t.Fatal("Connection should have been closed")
  1380  	}
  1381  }
  1382  
  1383  func TestWSUpgradeConnDeadline(t *testing.T) {
  1384  	opts := testWSOptions()
  1385  	opts.Websocket.HandshakeTimeout = time.Second
  1386  	s := &Server{opts: opts}
  1387  	rw := &testResponseWriter{}
  1388  	req := testWSCreateValidReq()
  1389  	res, err := s.wsUpgrade(rw, req)
  1390  	if res == nil || err != nil {
  1391  		t.Fatalf("Unexpected error: %v", err)
  1392  	}
  1393  	if rw.conn.isClosed {
  1394  		t.Fatal("Connection should NOT have been closed")
  1395  	}
  1396  	if !rw.conn.deadlineCleared {
  1397  		t.Fatal("Connection deadline should have been cleared after handshake")
  1398  	}
  1399  }
  1400  
  1401  func TestWSCompressNegotiation(t *testing.T) {
  1402  	// No compression on the server, but client asks
  1403  	opts := testWSOptions()
  1404  	s := &Server{opts: opts}
  1405  	rw := &testResponseWriter{}
  1406  	req := testWSCreateValidReq()
  1407  	req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
  1408  	res, err := s.wsUpgrade(rw, req)
  1409  	if res == nil || err != nil {
  1410  		t.Fatalf("Unexpected error: %v", err)
  1411  	}
  1412  	// The http response should not contain "permessage-deflate"
  1413  	output := rw.conn.wbuf.String()
  1414  	if strings.Contains(output, "permessage-deflate") {
  1415  		t.Fatalf("Compression disabled in server so response to client should not contain extension, got %s", output)
  1416  	}
  1417  
  1418  	// Option in the server and client, so compression should be negotiated.
  1419  	s.opts.Websocket.Compression = true
  1420  	rw = &testResponseWriter{}
  1421  	res, err = s.wsUpgrade(rw, req)
  1422  	if res == nil || err != nil {
  1423  		t.Fatalf("Unexpected error: %v", err)
  1424  	}
  1425  	// The http response should not contain "permessage-deflate"
  1426  	output = rw.conn.wbuf.String()
  1427  	if !strings.Contains(output, "permessage-deflate") {
  1428  		t.Fatalf("Compression in server and client request, so response should contain extension, got %s", output)
  1429  	}
  1430  
  1431  	// Option in server but not asked by the client, so response should not contain "permessage-deflate"
  1432  	rw = &testResponseWriter{}
  1433  	req.Header.Del("Sec-Websocket-Extensions")
  1434  	res, err = s.wsUpgrade(rw, req)
  1435  	if res == nil || err != nil {
  1436  		t.Fatalf("Unexpected error: %v", err)
  1437  	}
  1438  	// The http response should not contain "permessage-deflate"
  1439  	output = rw.conn.wbuf.String()
  1440  	if strings.Contains(output, "permessage-deflate") {
  1441  		t.Fatalf("Compression in server but not in client, so response to client should not contain extension, got %s", output)
  1442  	}
  1443  }
  1444  
  1445  func TestWSParseOptions(t *testing.T) {
  1446  	for _, test := range []struct {
  1447  		name     string
  1448  		content  string
  1449  		checkOpt func(*WebsocketOpts) error
  1450  		err      string
  1451  	}{
  1452  		// Negative tests
  1453  		{"bad type", "websocket: []", nil, "to be a map"},
  1454  		{"bad listen", "websocket: { listen: [] }", nil, "port or host:port"},
  1455  		{"bad port", `websocket: { port: "abc" }`, nil, "not int64"},
  1456  		{"bad host", `websocket: { host: 123 }`, nil, "not string"},
  1457  		{"bad advertise type", `websocket: { advertise: 123 }`, nil, "not string"},
  1458  		{"bad tls", `websocket: { tls: 123 }`, nil, "not map[string]interface {}"},
  1459  		{"bad same origin", `websocket: { same_origin: "abc" }`, nil, "not bool"},
  1460  		{"bad allowed origins type", `websocket: { allowed_origins: {} }`, nil, "unsupported type"},
  1461  		{"bad allowed origins values", `websocket: { allowed_origins: [ {} ] }`, nil, "unsupported type in array"},
  1462  		{"bad handshake timeout type", `websocket: { handshake_timeout: [] }`, nil, "unsupported type"},
  1463  		{"bad handshake timeout duration", `websocket: { handshake_timeout: "abc" }`, nil, "invalid duration"},
  1464  		{"unknown field", `websocket: { this_does_not_exist: 123 }`, nil, "unknown"},
  1465  		// Positive tests
  1466  		{"listen port only", `websocket { listen: 1234 }`, func(wo *WebsocketOpts) error {
  1467  			if wo.Port != 1234 {
  1468  				return fmt.Errorf("expected 1234, got %v", wo.Port)
  1469  			}
  1470  			return nil
  1471  		}, ""},
  1472  		{"listen host and port", `websocket { listen: "localhost:1234" }`, func(wo *WebsocketOpts) error {
  1473  			if wo.Host != "localhost" || wo.Port != 1234 {
  1474  				return fmt.Errorf("expected localhost:1234, got %v:%v", wo.Host, wo.Port)
  1475  			}
  1476  			return nil
  1477  		}, ""},
  1478  		{"host", `websocket { host: "localhost" }`, func(wo *WebsocketOpts) error {
  1479  			if wo.Host != "localhost" {
  1480  				return fmt.Errorf("expected localhost, got %v", wo.Host)
  1481  			}
  1482  			return nil
  1483  		}, ""},
  1484  		{"port", `websocket { port: 1234 }`, func(wo *WebsocketOpts) error {
  1485  			if wo.Port != 1234 {
  1486  				return fmt.Errorf("expected 1234, got %v", wo.Port)
  1487  			}
  1488  			return nil
  1489  		}, ""},
  1490  		{"advertise", `websocket { advertise: "host:1234" }`, func(wo *WebsocketOpts) error {
  1491  			if wo.Advertise != "host:1234" {
  1492  				return fmt.Errorf("expected %q, got %q", "host:1234", wo.Advertise)
  1493  			}
  1494  			return nil
  1495  		}, ""},
  1496  		{"same origin", `websocket { same_origin: true }`, func(wo *WebsocketOpts) error {
  1497  			if !wo.SameOrigin {
  1498  				return fmt.Errorf("expected same_origin==true, got %v", wo.SameOrigin)
  1499  			}
  1500  			return nil
  1501  		}, ""},
  1502  		{"allowed origins one only", `websocket { allowed_origins: "https://host.com/" }`, func(wo *WebsocketOpts) error {
  1503  			expected := []string{"https://host.com/"}
  1504  			if !reflect.DeepEqual(wo.AllowedOrigins, expected) {
  1505  				return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins)
  1506  			}
  1507  			return nil
  1508  		}, ""},
  1509  		{"allowed origins array",
  1510  			`
  1511  			websocket {
  1512  				allowed_origins: [
  1513  					"https://host1.com/"
  1514  					"https://host2.com/"
  1515  				]
  1516  			}
  1517  			`, func(wo *WebsocketOpts) error {
  1518  				expected := []string{"https://host1.com/", "https://host2.com/"}
  1519  				if !reflect.DeepEqual(wo.AllowedOrigins, expected) {
  1520  					return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins)
  1521  				}
  1522  				return nil
  1523  			}, ""},
  1524  		{"handshake timeout in whole seconds", `websocket { handshake_timeout: 3 }`, func(wo *WebsocketOpts) error {
  1525  			if wo.HandshakeTimeout != 3*time.Second {
  1526  				return fmt.Errorf("expected handshake to be 3s, got %v", wo.HandshakeTimeout)
  1527  			}
  1528  			return nil
  1529  		}, ""},
  1530  		{"handshake timeout n duration", `websocket { handshake_timeout: "4s" }`, func(wo *WebsocketOpts) error {
  1531  			if wo.HandshakeTimeout != 4*time.Second {
  1532  				return fmt.Errorf("expected handshake to be 4s, got %v", wo.HandshakeTimeout)
  1533  			}
  1534  			return nil
  1535  		}, ""},
  1536  		{"tls config",
  1537  			`
  1538  			websocket {
  1539  				tls {
  1540  					cert_file: "./configs/certs/server.pem"
  1541  					key_file: "./configs/certs/key.pem"
  1542  				}
  1543  			}
  1544  			`, func(wo *WebsocketOpts) error {
  1545  				if wo.TLSConfig == nil {
  1546  					return fmt.Errorf("TLSConfig should have been set")
  1547  				}
  1548  				return nil
  1549  			}, ""},
  1550  		{"compression",
  1551  			`
  1552  			websocket {
  1553  				compression: true
  1554  			}
  1555  			`, func(wo *WebsocketOpts) error {
  1556  				if !wo.Compression {
  1557  					return fmt.Errorf("Compression should have been set")
  1558  				}
  1559  				return nil
  1560  			}, ""},
  1561  		{"jwt cookie",
  1562  			`
  1563  			websocket {
  1564  				jwt_cookie: "jwtcookie"
  1565  			}
  1566  			`, func(wo *WebsocketOpts) error {
  1567  				if wo.JWTCookie != "jwtcookie" {
  1568  					return fmt.Errorf("Invalid JWTCookie value: %q", wo.JWTCookie)
  1569  				}
  1570  				return nil
  1571  			}, ""},
  1572  		{"no auth user",
  1573  			`
  1574  			websocket {
  1575  				no_auth_user: "noauthuser"
  1576  			}
  1577  			`, func(wo *WebsocketOpts) error {
  1578  				if wo.NoAuthUser != "noauthuser" {
  1579  					return fmt.Errorf("Invalid NoAuthUser value: %q", wo.NoAuthUser)
  1580  				}
  1581  				return nil
  1582  			}, ""},
  1583  		{"auth block",
  1584  			`
  1585  			websocket {
  1586  				authorization {
  1587  					user: "webuser"
  1588  					password: "pwd"
  1589  					token: "token"
  1590  					timeout: 2.0
  1591  				}
  1592  			}
  1593  			`, func(wo *WebsocketOpts) error {
  1594  				if wo.Username != "webuser" || wo.Password != "pwd" || wo.Token != "token" || wo.AuthTimeout != 2.0 {
  1595  					return fmt.Errorf("Invalid auth block: %+v", wo)
  1596  				}
  1597  				return nil
  1598  			}, ""},
  1599  		{"auth timeout as int",
  1600  			`
  1601  			websocket {
  1602  				authorization {
  1603  					timeout: 2
  1604  				}
  1605  			}
  1606  			`, func(wo *WebsocketOpts) error {
  1607  				if wo.AuthTimeout != 2.0 {
  1608  					return fmt.Errorf("Invalid auth timeout: %v", wo.AuthTimeout)
  1609  				}
  1610  				return nil
  1611  			}, ""},
  1612  	} {
  1613  		t.Run(test.name, func(t *testing.T) {
  1614  			conf := createConfFile(t, []byte(test.content))
  1615  			o, err := ProcessConfigFile(conf)
  1616  			if test.err != _EMPTY_ {
  1617  				if err == nil || !strings.Contains(err.Error(), test.err) {
  1618  					t.Fatalf("For content: %q, expected error about %q, got %v", test.content, test.err, err)
  1619  				}
  1620  				return
  1621  			} else if err != nil {
  1622  				t.Fatalf("Unexpected error for content %q: %v", test.content, err)
  1623  			}
  1624  			if err := test.checkOpt(&o.Websocket); err != nil {
  1625  				t.Fatalf("Incorrect option for content %q: %v", test.content, err.Error())
  1626  			}
  1627  		})
  1628  	}
  1629  }
  1630  
  1631  func TestWSValidateOptions(t *testing.T) {
  1632  	nwso := DefaultOptions()
  1633  	wso := testWSOptions()
  1634  	for _, test := range []struct {
  1635  		name    string
  1636  		getOpts func() *Options
  1637  		err     string
  1638  	}{
  1639  		{"websocket disabled", func() *Options { return nwso.Clone() }, ""},
  1640  		{"no tls", func() *Options { o := wso.Clone(); o.Websocket.TLSConfig = nil; return o }, "requires TLS configuration"},
  1641  		{"bad url in allowed list", func() *Options {
  1642  			o := wso.Clone()
  1643  			o.Websocket.AllowedOrigins = []string{"http://this:is:bad:url"}
  1644  			return o
  1645  		}, "unable to parse"},
  1646  		{"missing trusted configuration", func() *Options {
  1647  			o := wso.Clone()
  1648  			o.Websocket.JWTCookie = "jwt"
  1649  			return o
  1650  		}, "keys configuration is required"},
  1651  		{"websocket username not allowed if users specified", func() *Options {
  1652  			o := wso.Clone()
  1653  			o.Nkeys = []*NkeyUser{{Nkey: "abc"}}
  1654  			o.Websocket.Username = "b"
  1655  			o.Websocket.Password = "pwd"
  1656  			return o
  1657  		}, "websocket authentication username not compatible with presence of users/nkeys"},
  1658  		{"websocket token not allowed if users specified", func() *Options {
  1659  			o := wso.Clone()
  1660  			o.Nkeys = []*NkeyUser{{Nkey: "abc"}}
  1661  			o.Websocket.Token = "mytoken"
  1662  			return o
  1663  		}, "websocket authentication token not compatible with presence of users/nkeys"},
  1664  	} {
  1665  		t.Run(test.name, func(t *testing.T) {
  1666  			err := validateWebsocketOptions(test.getOpts())
  1667  			if test.err == "" && err != nil {
  1668  				t.Fatalf("Unexpected error: %v", err)
  1669  			} else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) {
  1670  				t.Fatalf("Expected error to contain %q, got %v", test.err, err)
  1671  			}
  1672  		})
  1673  	}
  1674  }
  1675  
  1676  func TestWSSetOriginOptions(t *testing.T) {
  1677  	o := testWSOptions()
  1678  	for _, test := range []struct {
  1679  		content string
  1680  		err     string
  1681  	}{
  1682  		{"@@@://host.com/", "invalid URI"},
  1683  		{"http://this:is:bad:url/", "invalid port"},
  1684  	} {
  1685  		t.Run(test.err, func(t *testing.T) {
  1686  			o.Websocket.AllowedOrigins = []string{test.content}
  1687  			s := &Server{}
  1688  			l := &captureErrorLogger{errCh: make(chan string, 1)}
  1689  			s.SetLogger(l, false, false)
  1690  			s.wsSetOriginOptions(&o.Websocket)
  1691  			select {
  1692  			case e := <-l.errCh:
  1693  				if !strings.Contains(e, test.err) {
  1694  					t.Fatalf("Unexpected error: %v", e)
  1695  				}
  1696  			case <-time.After(50 * time.Millisecond):
  1697  				t.Fatalf("Did not get the error")
  1698  			}
  1699  
  1700  		})
  1701  	}
  1702  }
  1703  
// captureFatalLogger is a test logger that forwards Fatalf messages to
// fatalCh instead of terminating. All other logging methods come from the
// embedded DummyLogger.
type captureFatalLogger struct {
	DummyLogger
	// fatalCh receives formatted Fatalf messages. Sends are non-blocking,
	// so messages are dropped once the channel buffer is full.
	fatalCh chan string
}
  1708  
  1709  func (l *captureFatalLogger) Fatalf(format string, v ...interface{}) {
  1710  	select {
  1711  	case l.fatalCh <- fmt.Sprintf(format, v...):
  1712  	default:
  1713  	}
  1714  }
  1715  
  1716  func TestWSFailureToStartServer(t *testing.T) {
  1717  	// Create a listener to use a port
  1718  	l, err := net.Listen("tcp", "127.0.0.1:0")
  1719  	if err != nil {
  1720  		t.Fatalf("Error listening: %v", err)
  1721  	}
  1722  	defer l.Close()
  1723  
  1724  	o := testWSOptions()
  1725  	// Make sure we don't have unnecessary listen ports opened.
  1726  	o.HTTPPort = 0
  1727  	o.Cluster.Port = 0
  1728  	o.Gateway.Name = ""
  1729  	o.Gateway.Port = 0
  1730  	o.LeafNode.Port = 0
  1731  	o.Websocket.Port = l.Addr().(*net.TCPAddr).Port
  1732  	s, err := NewServer(o)
  1733  	if err != nil {
  1734  		t.Fatalf("Error creating server: %v", err)
  1735  	}
  1736  	defer s.Shutdown()
  1737  	logger := &captureFatalLogger{fatalCh: make(chan string, 1)}
  1738  	s.SetLogger(logger, false, false)
  1739  
  1740  	wg := sync.WaitGroup{}
  1741  	wg.Add(1)
  1742  	go func() {
  1743  		s.Start()
  1744  		wg.Done()
  1745  	}()
  1746  
  1747  	select {
  1748  	case e := <-logger.fatalCh:
  1749  		if !strings.Contains(e, "Unable to listen") {
  1750  			t.Fatalf("Unexpected error: %v", e)
  1751  		}
  1752  	case <-time.After(2 * time.Second):
  1753  		t.Fatalf("Should have reported a fatal error")
  1754  	}
  1755  	// Since this is a test and the process does not actually
  1756  	// exit on Fatal error, wait for the client port to be
  1757  	// ready so when we shutdown we don't leave the accept
  1758  	// loop hanging.
  1759  	checkFor(t, time.Second, 15*time.Millisecond, func() error {
  1760  		s.mu.Lock()
  1761  		ready := s.listener != nil
  1762  		s.mu.Unlock()
  1763  		if !ready {
  1764  			return fmt.Errorf("client accept loop not started yet")
  1765  		}
  1766  		return nil
  1767  	})
  1768  	s.Shutdown()
  1769  	wg.Wait()
  1770  }
  1771  
  1772  func TestWSAbnormalFailureOfWebServer(t *testing.T) {
  1773  	o := testWSOptions()
  1774  	s := RunServer(o)
  1775  	defer s.Shutdown()
  1776  	logger := &captureFatalLogger{fatalCh: make(chan string, 1)}
  1777  	s.SetLogger(logger, false, false)
  1778  
  1779  	// Now close the WS listener to cause a WebServer error
  1780  	s.mu.Lock()
  1781  	s.websocket.listener.Close()
  1782  	s.mu.Unlock()
  1783  
  1784  	select {
  1785  	case e := <-logger.fatalCh:
  1786  		if !strings.Contains(e, "websocket listener error") {
  1787  			t.Fatalf("Unexpected error: %v", e)
  1788  		}
  1789  	case <-time.After(2 * time.Second):
  1790  		t.Fatalf("Should have reported a fatal error")
  1791  	}
  1792  }
  1793  
  1794  type testWSClientOptions struct {
  1795  	compress, web bool
  1796  	host          string
  1797  	port          int
  1798  	extraHeaders  map[string][]string
  1799  	noTLS         bool
  1800  	path          string
  1801  }
  1802  
  1803  func testNewWSClient(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte) {
  1804  	t.Helper()
  1805  	c, br, info, err := testNewWSClientWithError(t, o)
  1806  	if err != nil {
  1807  		t.Fatal(err)
  1808  	}
  1809  	return c, br, info
  1810  }
  1811  
  1812  func testNewWSClientWithError(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte, error) {
  1813  	addr := fmt.Sprintf("%s:%d", o.host, o.port)
  1814  	wsc, err := net.Dial("tcp", addr)
  1815  	if err != nil {
  1816  		return nil, nil, nil, fmt.Errorf("Error creating ws connection: %v", err)
  1817  	}
  1818  	if !o.noTLS {
  1819  		wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
  1820  		wsc.SetDeadline(time.Now().Add(time.Second))
  1821  		if err := wsc.(*tls.Conn).Handshake(); err != nil {
  1822  			return nil, nil, nil, fmt.Errorf("Error during handshake: %v", err)
  1823  		}
  1824  		wsc.SetDeadline(time.Time{})
  1825  	}
  1826  	req := testWSCreateValidReq()
  1827  	if o.compress {
  1828  		req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
  1829  	}
  1830  	if o.web {
  1831  		req.Header.Set("User-Agent", "Mozilla/5.0")
  1832  	}
  1833  	if len(o.extraHeaders) > 0 {
  1834  		for hdr, values := range o.extraHeaders {
  1835  			if len(values) == 0 {
  1836  				req.Header.Set(hdr, _EMPTY_)
  1837  				continue
  1838  			}
  1839  			req.Header.Set(hdr, values[0])
  1840  			for i := 1; i < len(values); i++ {
  1841  				req.Header.Add(hdr, values[i])
  1842  			}
  1843  		}
  1844  	}
  1845  	req.URL, _ = url.Parse("wss://" + addr + o.path)
  1846  	if err := req.Write(wsc); err != nil {
  1847  		return nil, nil, nil, fmt.Errorf("Error sending request: %v", err)
  1848  	}
  1849  	br := bufio.NewReader(wsc)
  1850  	resp, err := http.ReadResponse(br, req)
  1851  	if err != nil {
  1852  		return nil, nil, nil, fmt.Errorf("Error reading response: %v", err)
  1853  	}
  1854  	defer resp.Body.Close()
  1855  	if resp.StatusCode != http.StatusSwitchingProtocols {
  1856  		return nil, nil, nil, fmt.Errorf("Expected response status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
  1857  	}
  1858  	var info []byte
  1859  	if o.path == mqttWSPath {
  1860  		if v := resp.Header[wsSecProto]; len(v) != 1 || v[0] != wsMQTTSecProtoVal {
  1861  			return nil, nil, nil, fmt.Errorf("No mqtt protocol in header: %v", resp.Header)
  1862  		}
  1863  	} else {
  1864  		// Wait for the INFO
  1865  		info = testWSReadFrame(t, br)
  1866  		if !bytes.HasPrefix(info, []byte("INFO {")) {
  1867  			return nil, nil, nil, fmt.Errorf("Expected INFO, got %s", info)
  1868  		}
  1869  	}
  1870  	return wsc, br, info, nil
  1871  }
  1872  
// testClaimsOptions parameterizes testWSWithClaims.
type testClaimsOptions struct {
	// nac are the account claims to sign. If nil, fresh claims are created;
	// otherwise their Subject is overwritten with the generated account key.
	nac *jwt.AccountClaims
	// nuc are the user claims to sign. If nil, fresh claims are created;
	// otherwise their Subject is overwritten with the generated user key.
	nuc *jwt.UserClaims
	// connectRequest, when non-nil, is JSON-marshaled and sent verbatim as
	// the CONNECT payload instead of the generated one.
	connectRequest interface{}
	// dontSign sends a CONNECT carrying the user JWT but no nonce signature
	// (only used when connectRequest is nil).
	dontSign bool
	// expectAnswer is the prefix the reply frame to CONNECT+PING must have.
	expectAnswer string
}
  1880  
  1881  func testWSWithClaims(t *testing.T, s *Server, o testWSClientOptions, tclm testClaimsOptions) (kp nkeys.KeyPair, conn net.Conn, rdr *bufio.Reader, auth_was_required bool) {
  1882  	t.Helper()
  1883  
  1884  	okp, _ := nkeys.FromSeed(oSeed)
  1885  
  1886  	akp, _ := nkeys.CreateAccount()
  1887  	apub, _ := akp.PublicKey()
  1888  	if tclm.nac == nil {
  1889  		tclm.nac = jwt.NewAccountClaims(apub)
  1890  	} else {
  1891  		tclm.nac.Subject = apub
  1892  	}
  1893  	ajwt, err := tclm.nac.Encode(okp)
  1894  	if err != nil {
  1895  		t.Fatalf("Error generating account JWT: %v", err)
  1896  	}
  1897  
  1898  	nkp, _ := nkeys.CreateUser()
  1899  	pub, _ := nkp.PublicKey()
  1900  	if tclm.nuc == nil {
  1901  		tclm.nuc = jwt.NewUserClaims(pub)
  1902  	} else {
  1903  		tclm.nuc.Subject = pub
  1904  	}
  1905  	jwt, err := tclm.nuc.Encode(akp)
  1906  	if err != nil {
  1907  		t.Fatalf("Error generating user JWT: %v", err)
  1908  	}
  1909  
  1910  	addAccountToMemResolver(s, apub, ajwt)
  1911  
  1912  	c, cr, l := testNewWSClient(t, o)
  1913  
  1914  	var info struct {
  1915  		Nonce        string `json:"nonce,omitempty"`
  1916  		AuthRequired bool   `json:"auth_required,omitempty"`
  1917  	}
  1918  
  1919  	if err := json.Unmarshal([]byte(l[5:]), &info); err != nil {
  1920  		t.Fatal(err)
  1921  	}
  1922  	if info.AuthRequired {
  1923  		cs := ""
  1924  		if tclm.connectRequest != nil {
  1925  			customReq, err := json.Marshal(tclm.connectRequest)
  1926  			if err != nil {
  1927  				t.Fatal(err)
  1928  			}
  1929  			// PING needed to flush the +OK/-ERR to us.
  1930  			cs = fmt.Sprintf("CONNECT %v\r\nPING\r\n", string(customReq))
  1931  		} else if !tclm.dontSign {
  1932  			// Sign Nonce
  1933  			sigraw, _ := nkp.Sign([]byte(info.Nonce))
  1934  			sig := base64.RawURLEncoding.EncodeToString(sigraw)
  1935  			cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"sig\":\"%s\",\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt, sig)
  1936  		} else {
  1937  			cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt)
  1938  		}
  1939  		wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(cs))
  1940  		c.Write(wsmsg)
  1941  		l = testWSReadFrame(t, cr)
  1942  		if !strings.HasPrefix(string(l), tclm.expectAnswer) {
  1943  			t.Fatalf("Expected %q, got %q", tclm.expectAnswer, l)
  1944  		}
  1945  	}
  1946  	return akp, c, cr, info.AuthRequired
  1947  }
  1948  
  1949  func setupAddTrusted(o *Options) {
  1950  	kp, _ := nkeys.FromSeed(oSeed)
  1951  	pub, _ := kp.PublicKey()
  1952  	o.TrustedKeys = []string{pub}
  1953  }
  1954  
  1955  func setupAddCookie(o *Options) {
  1956  	o.Websocket.JWTCookie = "jwt"
  1957  }
  1958  
  1959  func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int, cookies ...string) (net.Conn, *bufio.Reader, []byte) {
  1960  	t.Helper()
  1961  	opts := testWSClientOptions{
  1962  		compress: compress,
  1963  		web:      web,
  1964  		host:     host,
  1965  		port:     port,
  1966  	}
  1967  
  1968  	if len(cookies) > 0 {
  1969  		opts.extraHeaders = map[string][]string{}
  1970  		opts.extraHeaders["Cookie"] = cookies
  1971  	}
  1972  	return testNewWSClient(t, opts)
  1973  }
  1974  
  1975  func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) {
  1976  	wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port)
  1977  	// Send CONNECT and PING
  1978  	wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
  1979  	if _, err := wsc.Write(wsmsg); err != nil {
  1980  		t.Fatalf("Error sending message: %v", err)
  1981  	}
  1982  	// Wait for the PONG
  1983  	if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  1984  		t.Fatalf("Expected PONG, got %s", msg)
  1985  	}
  1986  	return wsc, br
  1987  }
  1988  
  1989  func testWSReadFrame(t testing.TB, br *bufio.Reader) []byte {
  1990  	t.Helper()
  1991  	fh := [2]byte{}
  1992  	if _, err := io.ReadAtLeast(br, fh[:2], 2); err != nil {
  1993  		t.Fatalf("Error reading frame: %v", err)
  1994  	}
  1995  	fc := fh[0]&wsRsv1Bit != 0
  1996  	sb := fh[1]
  1997  	size := 0
  1998  	switch {
  1999  	case sb <= 125:
  2000  		size = int(sb)
  2001  	case sb == 126:
  2002  		tmp := [2]byte{}
  2003  		if _, err := io.ReadAtLeast(br, tmp[:2], 2); err != nil {
  2004  			t.Fatalf("Error reading frame: %v", err)
  2005  		}
  2006  		size = int(binary.BigEndian.Uint16(tmp[:2]))
  2007  	case sb == 127:
  2008  		tmp := [8]byte{}
  2009  		if _, err := io.ReadAtLeast(br, tmp[:8], 8); err != nil {
  2010  			t.Fatalf("Error reading frame: %v", err)
  2011  		}
  2012  		size = int(binary.BigEndian.Uint64(tmp[:8]))
  2013  	}
  2014  	buf := make([]byte, size)
  2015  	if _, err := io.ReadAtLeast(br, buf, size); err != nil {
  2016  		t.Fatalf("Error reading frame: %v", err)
  2017  	}
  2018  	if !fc {
  2019  		return buf
  2020  	}
  2021  	buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
  2022  	dbr := bytes.NewBuffer(buf)
  2023  	d := flate.NewReader(dbr)
  2024  	uncompressed, err := io.ReadAll(d)
  2025  	if err != nil {
  2026  		t.Fatalf("Error reading frame: %v", err)
  2027  	}
  2028  	return uncompressed
  2029  }
  2030  
  2031  func TestWSPubSub(t *testing.T) {
  2032  	for _, test := range []struct {
  2033  		name        string
  2034  		compression bool
  2035  	}{
  2036  		{"no compression", false},
  2037  		{"compression", true},
  2038  	} {
  2039  		t.Run(test.name, func(t *testing.T) {
  2040  			o := testWSOptions()
  2041  			if test.compression {
  2042  				o.Websocket.Compression = true
  2043  			}
  2044  			s := RunServer(o)
  2045  			defer s.Shutdown()
  2046  
  2047  			// Create a regular client to subscribe
  2048  			nc := natsConnect(t, s.ClientURL())
  2049  			defer nc.Close()
  2050  			nsub := natsSubSync(t, nc, "foo")
  2051  			checkExpectedSubs(t, 1, s)
  2052  
  2053  			// Now create a WS client and send a message on "foo"
  2054  			wsc, br := testWSCreateClient(t, test.compression, false, o.Websocket.Host, o.Websocket.Port)
  2055  			defer wsc.Close()
  2056  
  2057  			// Send a WS message for "PUB foo 2\r\nok\r\n"
  2058  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n"))
  2059  			if _, err := wsc.Write(wsmsg); err != nil {
  2060  				t.Fatalf("Error sending message: %v", err)
  2061  			}
  2062  
  2063  			// Now check that message is received
  2064  			msg := natsNexMsg(t, nsub, time.Second)
  2065  			if string(msg.Data) != "from ws" {
  2066  				t.Fatalf("Expected message to be %q, got %q", "ok", string(msg.Data))
  2067  			}
  2068  
  2069  			// Now do reverse, create a subscription on WS client on bar
  2070  			wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB bar 1\r\n"))
  2071  			if _, err := wsc.Write(wsmsg); err != nil {
  2072  				t.Fatalf("Error sending subscription: %v", err)
  2073  			}
  2074  			// Wait for it to be registered on server
  2075  			checkExpectedSubs(t, 2, s)
  2076  			// Now publish from NATS connection and verify received on WS client
  2077  			natsPub(t, nc, "bar", []byte("from nats"))
  2078  			natsFlush(t, nc)
  2079  
  2080  			// Check for the "from nats" message...
  2081  			// Set some deadline so we are not stuck forever on failure
  2082  			wsc.SetReadDeadline(time.Now().Add(10 * time.Second))
  2083  			ok := 0
  2084  			for {
  2085  				line, _, err := br.ReadLine()
  2086  				if err != nil {
  2087  					t.Fatalf("Error reading: %v", err)
  2088  				}
  2089  				// Note that this works even in compression test because those
  2090  				// texts are likely not to be compressed, but compression code is
  2091  				// still executed.
  2092  				if ok == 0 && bytes.Contains(line, []byte("MSG bar 1 9")) {
  2093  					ok = 1
  2094  					continue
  2095  				} else if ok == 1 && bytes.Contains(line, []byte("from nats")) {
  2096  					break
  2097  				}
  2098  			}
  2099  		})
  2100  	}
  2101  }
  2102  
  2103  func TestWSTLSConnection(t *testing.T) {
  2104  	o := testWSOptions()
  2105  	s := RunServer(o)
  2106  	defer s.Shutdown()
  2107  
  2108  	addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
  2109  
  2110  	for _, test := range []struct {
  2111  		name   string
  2112  		useTLS bool
  2113  		status int
  2114  	}{
  2115  		{"client uses TLS", true, http.StatusSwitchingProtocols},
  2116  		{"client does not use TLS", false, http.StatusBadRequest},
  2117  	} {
  2118  		t.Run(test.name, func(t *testing.T) {
  2119  			wsc, err := net.Dial("tcp", addr)
  2120  			if err != nil {
  2121  				t.Fatalf("Error creating ws connection: %v", err)
  2122  			}
  2123  			defer wsc.Close()
  2124  			if test.useTLS {
  2125  				wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
  2126  				if err := wsc.(*tls.Conn).Handshake(); err != nil {
  2127  					t.Fatalf("Error during handshake: %v", err)
  2128  				}
  2129  			}
  2130  			req := testWSCreateValidReq()
  2131  			var scheme string
  2132  			if test.useTLS {
  2133  				scheme = "s"
  2134  			}
  2135  			req.URL, _ = url.Parse("ws" + scheme + "://" + addr)
  2136  			if err := req.Write(wsc); err != nil {
  2137  				t.Fatalf("Error sending request: %v", err)
  2138  			}
  2139  			br := bufio.NewReader(wsc)
  2140  			resp, err := http.ReadResponse(br, req)
  2141  			if err != nil {
  2142  				t.Fatalf("Error reading response: %v", err)
  2143  			}
  2144  			defer resp.Body.Close()
  2145  			if resp.StatusCode != test.status {
  2146  				t.Fatalf("Expected status %v, got %v", test.status, resp.StatusCode)
  2147  			}
  2148  		})
  2149  	}
  2150  }
  2151  
  2152  func TestWSTLSVerifyClientCert(t *testing.T) {
  2153  	o := testWSOptions()
  2154  	tc := &TLSConfigOpts{
  2155  		CertFile: "../test/configs/certs/server-cert.pem",
  2156  		KeyFile:  "../test/configs/certs/server-key.pem",
  2157  		CaFile:   "../test/configs/certs/ca.pem",
  2158  		Verify:   true,
  2159  	}
  2160  	tlsc, err := GenTLSConfig(tc)
  2161  	if err != nil {
  2162  		t.Fatalf("Error creating tls config: %v", err)
  2163  	}
  2164  	o.Websocket.TLSConfig = tlsc
  2165  	s := RunServer(o)
  2166  	defer s.Shutdown()
  2167  
  2168  	addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
  2169  
  2170  	for _, test := range []struct {
  2171  		name        string
  2172  		provideCert bool
  2173  	}{
  2174  		{"client provides cert", true},
  2175  		{"client does not provide cert", false},
  2176  	} {
  2177  		t.Run(test.name, func(t *testing.T) {
  2178  			wsc, err := net.Dial("tcp", addr)
  2179  			if err != nil {
  2180  				t.Fatalf("Error creating ws connection: %v", err)
  2181  			}
  2182  			defer wsc.Close()
  2183  			tlsc := &tls.Config{}
  2184  			if test.provideCert {
  2185  				tc := &TLSConfigOpts{
  2186  					CertFile: "../test/configs/certs/client-cert.pem",
  2187  					KeyFile:  "../test/configs/certs/client-key.pem",
  2188  				}
  2189  				var err error
  2190  				tlsc, err = GenTLSConfig(tc)
  2191  				if err != nil {
  2192  					t.Fatalf("Error generating tls config: %v", err)
  2193  				}
  2194  			}
  2195  			tlsc.InsecureSkipVerify = true
  2196  			wsc = tls.Client(wsc, tlsc)
  2197  			if err := wsc.(*tls.Conn).Handshake(); err != nil {
  2198  				t.Fatalf("Error during handshake: %v", err)
  2199  			}
  2200  			req := testWSCreateValidReq()
  2201  			req.URL, _ = url.Parse("wss://" + addr)
  2202  			if err := req.Write(wsc); err != nil {
  2203  				t.Fatalf("Error sending request: %v", err)
  2204  			}
  2205  			br := bufio.NewReader(wsc)
  2206  			resp, err := http.ReadResponse(br, req)
  2207  			if resp != nil {
  2208  				resp.Body.Close()
  2209  			}
  2210  			if !test.provideCert {
  2211  				if err == nil {
  2212  					t.Fatal("Expected error, did not get one")
  2213  				} else if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "certificate required") {
  2214  					t.Fatalf("Unexpected error: %v", err)
  2215  				}
  2216  				return
  2217  			}
  2218  			if err != nil {
  2219  				t.Fatalf("Unexpected error: %v", err)
  2220  			}
  2221  			if resp.StatusCode != http.StatusSwitchingProtocols {
  2222  				t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
  2223  			}
  2224  		})
  2225  	}
  2226  }
  2227  
  2228  func testCreateAllowedConnectionTypes(list []string) map[string]struct{} {
  2229  	if len(list) == 0 {
  2230  		return nil
  2231  	}
  2232  	m := make(map[string]struct{}, len(list))
  2233  	for _, l := range list {
  2234  		m[l] = struct{}{}
  2235  	}
  2236  	return m
  2237  }
  2238  
  2239  func TestWSTLSVerifyAndMap(t *testing.T) {
  2240  	accName := "MyAccount"
  2241  	acc := NewAccount(accName)
  2242  	certUserName := "CN=example.com,OU=NATS.io"
  2243  	users := []*User{{Username: certUserName, Account: acc}}
  2244  
  2245  	for _, test := range []struct {
  2246  		name        string
  2247  		filtering   bool
  2248  		provideCert bool
  2249  	}{
  2250  		{"no filtering, client provides cert", false, true},
  2251  		{"no filtering, client does not provide cert", false, false},
  2252  		{"filtering, client provides cert", true, true},
  2253  		{"filtering, client does not provide cert", true, false},
  2254  		{"no users override, client provides cert", false, true},
  2255  		{"no users override, client does not provide cert", false, false},
  2256  		{"users override, client provides cert", true, true},
  2257  		{"users override, client does not provide cert", true, false},
  2258  	} {
  2259  		t.Run(test.name, func(t *testing.T) {
  2260  			o := testWSOptions()
  2261  			o.Accounts = []*Account{acc}
  2262  			o.Users = users
  2263  			if test.filtering {
  2264  				o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket})
  2265  			}
  2266  			tc := &TLSConfigOpts{
  2267  				CertFile: "../test/configs/certs/tlsauth/server.pem",
  2268  				KeyFile:  "../test/configs/certs/tlsauth/server-key.pem",
  2269  				CaFile:   "../test/configs/certs/tlsauth/ca.pem",
  2270  				Verify:   true,
  2271  			}
  2272  			tlsc, err := GenTLSConfig(tc)
  2273  			if err != nil {
  2274  				t.Fatalf("Error creating tls config: %v", err)
  2275  			}
  2276  			o.Websocket.TLSConfig = tlsc
  2277  			o.Websocket.TLSMap = true
  2278  			s := RunServer(o)
  2279  			defer s.Shutdown()
  2280  
  2281  			addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
  2282  			wsc, err := net.Dial("tcp", addr)
  2283  			if err != nil {
  2284  				t.Fatalf("Error creating ws connection: %v", err)
  2285  			}
  2286  			defer wsc.Close()
  2287  			tlscc := &tls.Config{}
  2288  			if test.provideCert {
  2289  				tc := &TLSConfigOpts{
  2290  					CertFile: "../test/configs/certs/tlsauth/client.pem",
  2291  					KeyFile:  "../test/configs/certs/tlsauth/client-key.pem",
  2292  				}
  2293  				var err error
  2294  				tlscc, err = GenTLSConfig(tc)
  2295  				if err != nil {
  2296  					t.Fatalf("Error generating tls config: %v", err)
  2297  				}
  2298  			}
  2299  			tlscc.InsecureSkipVerify = true
  2300  			wsc = tls.Client(wsc, tlscc)
  2301  			if err := wsc.(*tls.Conn).Handshake(); err != nil {
  2302  				t.Fatalf("Error during handshake: %v", err)
  2303  			}
  2304  			req := testWSCreateValidReq()
  2305  			req.URL, _ = url.Parse("wss://" + addr)
  2306  			if err := req.Write(wsc); err != nil {
  2307  				t.Fatalf("Error sending request: %v", err)
  2308  			}
  2309  			br := bufio.NewReader(wsc)
  2310  			resp, err := http.ReadResponse(br, req)
  2311  			if resp != nil {
  2312  				resp.Body.Close()
  2313  			}
  2314  			if !test.provideCert {
  2315  				if err == nil {
  2316  					t.Fatal("Expected error, did not get one")
  2317  				} else if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "certificate required") {
  2318  					t.Fatalf("Unexpected error: %v", err)
  2319  				}
  2320  				return
  2321  			}
  2322  			if err != nil {
  2323  				t.Fatalf("Unexpected error: %v", err)
  2324  			}
  2325  			if resp.StatusCode != http.StatusSwitchingProtocols {
  2326  				t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
  2327  			}
  2328  			// Wait for the INFO
  2329  			l := testWSReadFrame(t, br)
  2330  			if !bytes.HasPrefix(l, []byte("INFO {")) {
  2331  				t.Fatalf("Expected INFO, got %s", l)
  2332  			}
  2333  			var info serverInfo
  2334  			if err := json.Unmarshal(l[5:], &info); err != nil {
  2335  				t.Fatalf("Unable to unmarshal info: %v", err)
  2336  			}
  2337  			// Send CONNECT and PING
  2338  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
  2339  			if _, err := wsc.Write(wsmsg); err != nil {
  2340  				t.Fatalf("Error sending message: %v", err)
  2341  			}
  2342  			// Wait for the PONG
  2343  			if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  2344  				t.Fatalf("Expected PONG, got %s", msg)
  2345  			}
  2346  
  2347  			var uname string
  2348  			var accname string
  2349  			c := s.getClient(info.CID)
  2350  			if c != nil {
  2351  				c.mu.Lock()
  2352  				uname = c.opts.Username
  2353  				if c.acc != nil {
  2354  					accname = c.acc.GetName()
  2355  				}
  2356  				c.mu.Unlock()
  2357  			}
  2358  			if uname != certUserName {
  2359  				t.Fatalf("Expected username %q, got %q", certUserName, uname)
  2360  			}
  2361  			if accname != accName {
  2362  				t.Fatalf("Expected account %q, got %v", accName, accname)
  2363  			}
  2364  		})
  2365  	}
  2366  }
  2367  
  2368  func TestWSHandshakeTimeout(t *testing.T) {
  2369  	o := testWSOptions()
  2370  	o.Websocket.HandshakeTimeout = time.Millisecond
  2371  	tc := &TLSConfigOpts{
  2372  		CertFile: "./configs/certs/server.pem",
  2373  		KeyFile:  "./configs/certs/key.pem",
  2374  	}
  2375  	o.Websocket.TLSConfig, _ = GenTLSConfig(tc)
  2376  	s := RunServer(o)
  2377  	defer s.Shutdown()
  2378  
  2379  	logger := &captureErrorLogger{errCh: make(chan string, 1)}
  2380  	s.SetLogger(logger, false, false)
  2381  
  2382  	addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port)
  2383  	wsc, err := net.Dial("tcp", addr)
  2384  	if err != nil {
  2385  		t.Fatalf("Error creating ws connection: %v", err)
  2386  	}
  2387  	defer wsc.Close()
  2388  
  2389  	// Delay the handshake
  2390  	wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
  2391  	time.Sleep(20 * time.Millisecond)
  2392  	// We expect error since the server should have cut us off
  2393  	if err := wsc.(*tls.Conn).Handshake(); err == nil {
  2394  		t.Fatal("Expected error during handshake")
  2395  	}
  2396  
  2397  	// Check that server logs error
  2398  	select {
  2399  	case e := <-logger.errCh:
  2400  		// Check that log starts with "websocket: "
  2401  		if !strings.HasPrefix(e, "websocket: ") {
  2402  			t.Fatalf("Wrong log line start: %s", e)
  2403  		}
  2404  		if !strings.Contains(e, "timeout") {
  2405  			t.Fatalf("Unexpected error: %v", e)
  2406  		}
  2407  	case <-time.After(time.Second):
  2408  		t.Fatalf("Should have timed-out")
  2409  	}
  2410  }
  2411  
  2412  func TestWSServerReportUpgradeFailure(t *testing.T) {
  2413  	o := testWSOptions()
  2414  	s := RunServer(o)
  2415  	defer s.Shutdown()
  2416  
  2417  	logger := &captureErrorLogger{errCh: make(chan string, 1)}
  2418  	s.SetLogger(logger, false, false)
  2419  
  2420  	addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port)
  2421  	req := testWSCreateValidReq()
  2422  	req.URL, _ = url.Parse("wss://" + addr)
  2423  
  2424  	wsc, err := net.Dial("tcp", addr)
  2425  	if err != nil {
  2426  		t.Fatalf("Error creating ws connection: %v", err)
  2427  	}
  2428  	defer wsc.Close()
  2429  	wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true})
  2430  	if err := wsc.(*tls.Conn).Handshake(); err != nil {
  2431  		t.Fatalf("Error during handshake: %v", err)
  2432  	}
  2433  	// Remove a required field from the request to have it fail
  2434  	req.Header.Del("Connection")
  2435  	// Send the request
  2436  	if err := req.Write(wsc); err != nil {
  2437  		t.Fatalf("Error sending request: %v", err)
  2438  	}
  2439  	br := bufio.NewReader(wsc)
  2440  	resp, err := http.ReadResponse(br, req)
  2441  	if err != nil {
  2442  		t.Fatalf("Error reading response: %v", err)
  2443  	}
  2444  	defer resp.Body.Close()
  2445  	if resp.StatusCode != http.StatusBadRequest {
  2446  		t.Fatalf("Expected status %v, got %v", http.StatusBadRequest, resp.StatusCode)
  2447  	}
  2448  
  2449  	// Check that server logs error
  2450  	select {
  2451  	case e := <-logger.errCh:
  2452  		if !strings.Contains(e, "invalid value for header 'Connection'") {
  2453  			t.Fatalf("Unexpected error: %v", e)
  2454  		}
  2455  		// The client IP's local should be printed as a remote from server perspective.
  2456  		clientIP := wsc.LocalAddr().String()
  2457  		if !strings.HasPrefix(e, clientIP) {
  2458  			t.Fatalf("IP should have been logged, it was not: %v", e)
  2459  		}
  2460  	case <-time.After(time.Second):
  2461  		t.Fatalf("Should have timed-out")
  2462  	}
  2463  }
  2464  
  2465  func TestWSCloseMsgSendOnConnectionClose(t *testing.T) {
  2466  	o := testWSOptions()
  2467  	s := RunServer(o)
  2468  	defer s.Shutdown()
  2469  
  2470  	wsc, br := testWSCreateClient(t, false, false, o.Websocket.Host, o.Websocket.Port)
  2471  	defer wsc.Close()
  2472  
  2473  	checkClientsCount(t, s, 1)
  2474  	var c *client
  2475  	s.mu.Lock()
  2476  	for _, cli := range s.clients {
  2477  		c = cli
  2478  		break
  2479  	}
  2480  	s.mu.Unlock()
  2481  
  2482  	c.closeConnection(ProtocolViolation)
  2483  	msg := testWSReadFrame(t, br)
  2484  	if len(msg) < 2 {
  2485  		t.Fatalf("Should have 2 bytes to represent the status, got %v", msg)
  2486  	}
  2487  	if sc := int(binary.BigEndian.Uint16(msg[:2])); sc != wsCloseStatusProtocolError {
  2488  		t.Fatalf("Expected status to be %v, got %v", wsCloseStatusProtocolError, sc)
  2489  	}
  2490  	expectedPayload := ProtocolViolation.String()
  2491  	if p := string(msg[2:]); p != expectedPayload {
  2492  		t.Fatalf("Expected payload to be %q, got %q", expectedPayload, p)
  2493  	}
  2494  }
  2495  
  2496  func TestWSAdvertise(t *testing.T) {
  2497  	o := testWSOptions()
  2498  	o.Cluster.Port = 0
  2499  	o.HTTPPort = 0
  2500  	o.Websocket.Advertise = "xxx:host:yyy"
  2501  	s, err := NewServer(o)
  2502  	if err != nil {
  2503  		t.Fatalf("Unexpected error: %v", err)
  2504  	}
  2505  	defer s.Shutdown()
  2506  	l := &captureFatalLogger{fatalCh: make(chan string, 1)}
  2507  	s.SetLogger(l, false, false)
  2508  	s.Start()
  2509  	select {
  2510  	case e := <-l.fatalCh:
  2511  		if !strings.Contains(e, "Unable to get websocket connect URLs") {
  2512  			t.Fatalf("Unexpected error: %q", e)
  2513  		}
  2514  	case <-time.After(time.Second):
  2515  		t.Fatal("Should have failed to start")
  2516  	}
  2517  	s.Shutdown()
  2518  
  2519  	o1 := testWSOptions()
  2520  	o1.Websocket.Advertise = "host1:1234"
  2521  	s1 := RunServer(o1)
  2522  	defer s1.Shutdown()
  2523  
  2524  	wsc, br := testWSCreateClient(t, false, false, o1.Websocket.Host, o1.Websocket.Port)
  2525  	defer wsc.Close()
  2526  
  2527  	o2 := testWSOptions()
  2528  	o2.Websocket.Advertise = "host2:5678"
  2529  	o2.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", o1.Cluster.Host, o1.Cluster.Port))
  2530  	s2 := RunServer(o2)
  2531  	defer s2.Shutdown()
  2532  
  2533  	checkInfo := func(expected []string) {
  2534  		t.Helper()
  2535  		infob := testWSReadFrame(t, br)
  2536  		info := &Info{}
  2537  		json.Unmarshal(infob[5:], info)
  2538  		if n := len(info.ClientConnectURLs); n != len(expected) {
  2539  			t.Fatalf("Unexpected info: %+v", info)
  2540  		}
  2541  		good := 0
  2542  		for _, u := range info.ClientConnectURLs {
  2543  			for _, eu := range expected {
  2544  				if u == eu {
  2545  					good++
  2546  				}
  2547  			}
  2548  		}
  2549  		if good != len(expected) {
  2550  			t.Fatalf("Unexpected connect urls: %q", info.ClientConnectURLs)
  2551  		}
  2552  	}
  2553  	checkInfo([]string{"host1:1234", "host2:5678"})
  2554  
  2555  	// Now shutdown s2 and expect another INFO
  2556  	s2.Shutdown()
  2557  	checkInfo([]string{"host1:1234"})
  2558  
  2559  	// Restart with another advertise and check that it gets updated
  2560  	o2.Websocket.Advertise = "host3:9012"
  2561  	s2 = RunServer(o2)
  2562  	defer s2.Shutdown()
  2563  	checkInfo([]string{"host1:1234", "host3:9012"})
  2564  }
  2565  
  2566  func TestWSFrameOutbound(t *testing.T) {
  2567  	for _, test := range []struct {
  2568  		name         string
  2569  		maskingWrite bool
  2570  	}{
  2571  		{"no write masking", false},
  2572  		{"write masking", true},
  2573  	} {
  2574  		t.Run(test.name, func(t *testing.T) {
  2575  			c, _, _ := testWSSetupForRead()
  2576  			c.ws.maskwrite = test.maskingWrite
  2577  
  2578  			getKey := func(buf []byte) []byte {
  2579  				return buf[len(buf)-4:]
  2580  			}
  2581  
  2582  			var bufs net.Buffers
  2583  			bufs = append(bufs, []byte("this "))
  2584  			bufs = append(bufs, []byte("is "))
  2585  			bufs = append(bufs, []byte("a "))
  2586  			bufs = append(bufs, []byte("set "))
  2587  			bufs = append(bufs, []byte("of "))
  2588  			bufs = append(bufs, []byte("buffers"))
  2589  			en := 2
  2590  			for _, b := range bufs {
  2591  				en += len(b)
  2592  			}
  2593  			if test.maskingWrite {
  2594  				en += 4
  2595  			}
  2596  			c.mu.Lock()
  2597  			c.out.nb = bufs
  2598  			res, n := c.collapsePtoNB()
  2599  			c.mu.Unlock()
  2600  			if n != int64(en) {
  2601  				t.Fatalf("Expected size to be %v, got %v", en, n)
  2602  			}
  2603  			if eb := 1 + len(bufs); eb != len(res) {
  2604  				t.Fatalf("Expected %v buffers, got %v", eb, len(res))
  2605  			}
  2606  			var ob []byte
  2607  			for i := 1; i < len(res); i++ {
  2608  				ob = append(ob, res[i]...)
  2609  			}
  2610  			if test.maskingWrite {
  2611  				wsMaskBuf(getKey(res[0]), ob)
  2612  			}
  2613  			if !bytes.Equal(ob, []byte("this is a set of buffers")) {
  2614  				t.Fatalf("Unexpected outbound: %q", ob)
  2615  			}
  2616  
  2617  			bufs = nil
  2618  			c.out.pb = 0
  2619  			c.ws.fs = 0
  2620  			c.ws.frames = nil
  2621  			c.ws.browser = true
  2622  			bufs = append(bufs, []byte("some smaller "))
  2623  			bufs = append(bufs, []byte("buffers"))
  2624  			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+10))
  2625  			bufs = append(bufs, []byte("then some more"))
  2626  			en = 2 + len(bufs[0]) + len(bufs[1])
  2627  			en += 4 + len(bufs[2]) - 10
  2628  			en += 2 + len(bufs[3]) + 10
  2629  			c.mu.Lock()
  2630  			c.out.nb = bufs
  2631  			res, n = c.collapsePtoNB()
  2632  			c.mu.Unlock()
  2633  			if test.maskingWrite {
  2634  				en += 3 * 4
  2635  			}
  2636  			if n != int64(en) {
  2637  				t.Fatalf("Expected size to be %v, got %v", en, n)
  2638  			}
  2639  			if len(res) != 8 {
  2640  				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
  2641  			}
  2642  			if len(res[4]) != wsFrameSizeForBrowsers {
  2643  				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
  2644  			}
  2645  			if len(res[6]) != 10 {
  2646  				t.Fatalf("Frame 6 should have the partial of 10 bytes, got %v", len(res[6]))
  2647  			}
  2648  			if test.maskingWrite {
  2649  				b := &bytes.Buffer{}
  2650  				key := getKey(res[0])
  2651  				b.Write(res[1])
  2652  				b.Write(res[2])
  2653  				ud := b.Bytes()
  2654  				wsMaskBuf(key, ud)
  2655  				if string(ud) != "some smaller buffers" {
  2656  					t.Fatalf("Unexpected result: %q", ud)
  2657  				}
  2658  
  2659  				b.Reset()
  2660  				key = getKey(res[3])
  2661  				b.Write(res[4])
  2662  				ud = b.Bytes()
  2663  				wsMaskBuf(key, ud)
  2664  				for i := 0; i < len(ud); i++ {
  2665  					if ud[i] != 0 {
  2666  						t.Fatalf("Unexpected result: %v", ud)
  2667  					}
  2668  				}
  2669  
  2670  				b.Reset()
  2671  				key = getKey(res[5])
  2672  				b.Write(res[6])
  2673  				b.Write(res[7])
  2674  				ud = b.Bytes()
  2675  				wsMaskBuf(key, ud)
  2676  				for i := 0; i < len(ud[:10]); i++ {
  2677  					if ud[i] != 0 {
  2678  						t.Fatalf("Unexpected result: %v", ud[:10])
  2679  					}
  2680  				}
  2681  				if string(ud[10:]) != "then some more" {
  2682  					t.Fatalf("Unexpected result: %q", ud[10:])
  2683  				}
  2684  			}
  2685  
  2686  			bufs = nil
  2687  			c.out.pb = 0
  2688  			c.ws.fs = 0
  2689  			c.ws.frames = nil
  2690  			c.ws.browser = true
  2691  			bufs = append(bufs, []byte("some smaller "))
  2692  			bufs = append(bufs, []byte("buffers"))
  2693  			// Have one of the exact max size
  2694  			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers))
  2695  			bufs = append(bufs, []byte("then some more"))
  2696  			en = 2 + len(bufs[0]) + len(bufs[1])
  2697  			en += 4 + len(bufs[2])
  2698  			en += 2 + len(bufs[3])
  2699  			c.mu.Lock()
  2700  			c.out.nb = bufs
  2701  			res, n = c.collapsePtoNB()
  2702  			c.mu.Unlock()
  2703  			if test.maskingWrite {
  2704  				en += 3 * 4
  2705  			}
  2706  			if n != int64(en) {
  2707  				t.Fatalf("Expected size to be %v, got %v", en, n)
  2708  			}
  2709  			if len(res) != 7 {
  2710  				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
  2711  			}
  2712  			if len(res[4]) != wsFrameSizeForBrowsers {
  2713  				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
  2714  			}
  2715  			if test.maskingWrite {
  2716  				key := getKey(res[5])
  2717  				wsMaskBuf(key, res[6])
  2718  			}
  2719  			if string(res[6]) != "then some more" {
  2720  				t.Fatalf("Frame 6 incorrect: %q", res[6])
  2721  			}
  2722  
  2723  			bufs = nil
  2724  			c.out.pb = 0
  2725  			c.ws.fs = 0
  2726  			c.ws.frames = nil
  2727  			c.ws.browser = true
  2728  			bufs = append(bufs, []byte("some smaller "))
  2729  			bufs = append(bufs, []byte("buffers"))
  2730  			// Have one of the exact max size, and last in the list
  2731  			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers))
  2732  			en = 2 + len(bufs[0]) + len(bufs[1])
  2733  			en += 4 + len(bufs[2])
  2734  			c.mu.Lock()
  2735  			c.out.nb = bufs
  2736  			res, n = c.collapsePtoNB()
  2737  			c.mu.Unlock()
  2738  			if test.maskingWrite {
  2739  				en += 2 * 4
  2740  			}
  2741  			if n != int64(en) {
  2742  				t.Fatalf("Expected size to be %v, got %v", en, n)
  2743  			}
  2744  			if len(res) != 5 {
  2745  				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
  2746  			}
  2747  			if len(res[4]) != wsFrameSizeForBrowsers {
  2748  				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
  2749  			}
  2750  
  2751  			bufs = nil
  2752  			c.out.pb = 0
  2753  			c.ws.fs = 0
  2754  			c.ws.frames = nil
  2755  			c.ws.browser = true
  2756  			bufs = append(bufs, []byte("some smaller buffer"))
  2757  			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers-5))
  2758  			bufs = append(bufs, []byte("then some more"))
  2759  			en = 2 + len(bufs[0])
  2760  			en += 4 + len(bufs[1])
  2761  			en += 2 + len(bufs[2])
  2762  			c.mu.Lock()
  2763  			c.out.nb = bufs
  2764  			res, n = c.collapsePtoNB()
  2765  			c.mu.Unlock()
  2766  			if test.maskingWrite {
  2767  				en += 3 * 4
  2768  			}
  2769  			if n != int64(en) {
  2770  				t.Fatalf("Expected size to be %v, got %v", en, n)
  2771  			}
  2772  			if len(res) != 6 {
  2773  				t.Fatalf("Unexpected number of outbound buffers: %v", len(res))
  2774  			}
  2775  			if len(res[3]) != wsFrameSizeForBrowsers-5 {
  2776  				t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4]))
  2777  			}
  2778  			if test.maskingWrite {
  2779  				key := getKey(res[4])
  2780  				wsMaskBuf(key, res[5])
  2781  			}
  2782  			if string(res[5]) != "then some more" {
  2783  				t.Fatalf("Frame 6 incorrect %q", res[5])
  2784  			}
  2785  
  2786  			bufs = nil
  2787  			c.out.pb = 0
  2788  			c.ws.fs = 0
  2789  			c.ws.frames = nil
  2790  			c.ws.browser = true
  2791  			bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+100))
  2792  			c.mu.Lock()
  2793  			c.out.nb = bufs
  2794  			res, _ = c.collapsePtoNB()
  2795  			c.mu.Unlock()
  2796  			if len(res) != 4 {
  2797  				t.Fatalf("Unexpected number of frames: %v", len(res))
  2798  			}
  2799  		})
  2800  	}
  2801  }
  2802  
  2803  func TestWSWebrowserClient(t *testing.T) {
  2804  	o := testWSOptions()
  2805  	s := RunServer(o)
  2806  	defer s.Shutdown()
  2807  
  2808  	wsc, br := testWSCreateClient(t, false, true, o.Websocket.Host, o.Websocket.Port)
  2809  	defer wsc.Close()
  2810  
  2811  	checkClientsCount(t, s, 1)
  2812  	var c *client
  2813  	s.mu.Lock()
  2814  	for _, cli := range s.clients {
  2815  		c = cli
  2816  		break
  2817  	}
  2818  	s.mu.Unlock()
  2819  
  2820  	proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB foo 1\r\nPING\r\n"))
  2821  	wsc.Write(proto)
  2822  	if res := testWSReadFrame(t, br); !bytes.Equal(res, []byte(pongProto)) {
  2823  		t.Fatalf("Expected PONG back")
  2824  	}
  2825  
  2826  	c.mu.Lock()
  2827  	ok := c.isWebsocket() && c.ws.browser == true
  2828  	c.mu.Unlock()
  2829  	if !ok {
  2830  		t.Fatalf("Client is not marked as webrowser client")
  2831  	}
  2832  
  2833  	nc := natsConnect(t, s.ClientURL())
  2834  	defer nc.Close()
  2835  
  2836  	// Send a big message and check that it is received in smaller frames
  2837  	psize := 204813
  2838  	nc.Publish("foo", make([]byte, psize))
  2839  	nc.Flush()
  2840  
  2841  	rsize := psize + len(fmt.Sprintf("MSG foo %d\r\n\r\n", psize))
  2842  	nframes := 0
  2843  	for total := 0; total < rsize; nframes++ {
  2844  		res := testWSReadFrame(t, br)
  2845  		total += len(res)
  2846  	}
  2847  	if expected := psize / wsFrameSizeForBrowsers; expected > nframes {
  2848  		t.Fatalf("Expected %v frames, got %v", expected, nframes)
  2849  	}
  2850  }
  2851  
  2852  type testWSWrappedConn struct {
  2853  	net.Conn
  2854  	mu      sync.RWMutex
  2855  	buf     *bytes.Buffer
  2856  	partial bool
  2857  }
  2858  
  2859  func (wc *testWSWrappedConn) Write(p []byte) (int, error) {
  2860  	wc.mu.Lock()
  2861  	defer wc.mu.Unlock()
  2862  	var err error
  2863  	n := len(p)
  2864  	if wc.partial && n > 10 {
  2865  		n = 10
  2866  		err = io.ErrShortWrite
  2867  	}
  2868  	p = p[:n]
  2869  	wc.buf.Write(p)
  2870  	wc.Conn.Write(p)
  2871  	return n, err
  2872  }
  2873  
  2874  func TestWSCompressionBasic(t *testing.T) {
  2875  	payload := "This is the content of a message that will be compresseddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd."
  2876  	msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
  2877  	cbuf := &bytes.Buffer{}
  2878  	compressor, err := flate.NewWriter(cbuf, flate.BestSpeed)
  2879  	require_NoError(t, err)
  2880  	compressor.Write([]byte(msgProto))
  2881  	compressor.Flush()
  2882  	compressed := cbuf.Bytes()
  2883  	// The last 4 bytes are dropped
  2884  	compressed = compressed[:len(compressed)-4]
  2885  
  2886  	o := testWSOptions()
  2887  	o.Websocket.Compression = true
  2888  	s := RunServer(o)
  2889  	defer s.Shutdown()
  2890  
  2891  	c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port)
  2892  	defer c.Close()
  2893  
  2894  	proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n"))
  2895  	c.Write(proto)
  2896  	l := testWSReadFrame(t, br)
  2897  	if !bytes.Equal(l, []byte(pongProto)) {
  2898  		t.Fatalf("Expected PONG, got %q", l)
  2899  	}
  2900  
  2901  	var wc *testWSWrappedConn
  2902  	s.mu.RLock()
  2903  	for _, c := range s.clients {
  2904  		c.mu.Lock()
  2905  		wc = &testWSWrappedConn{Conn: c.nc, buf: &bytes.Buffer{}}
  2906  		c.nc = wc
  2907  		c.mu.Unlock()
  2908  	}
  2909  	s.mu.RUnlock()
  2910  
  2911  	nc := natsConnect(t, s.ClientURL())
  2912  	defer nc.Close()
  2913  	natsPub(t, nc, "foo", []byte(payload))
  2914  
  2915  	res := &bytes.Buffer{}
  2916  	for total := 0; total < len(msgProto); {
  2917  		l := testWSReadFrame(t, br)
  2918  		n, _ := res.Write(l)
  2919  		total += n
  2920  	}
  2921  	if !bytes.Equal([]byte(msgProto), res.Bytes()) {
  2922  		t.Fatalf("Unexpected result: %q", res)
  2923  	}
  2924  
  2925  	// Now check the wrapped connection buffer to check that data was actually compressed.
  2926  	wc.mu.RLock()
  2927  	res = wc.buf
  2928  	wc.mu.RUnlock()
  2929  	if bytes.Contains(res.Bytes(), []byte(payload)) {
  2930  		t.Fatalf("Looks like frame was not compressed: %q", res.Bytes())
  2931  	}
  2932  	header := res.Bytes()[:2]
  2933  	body := res.Bytes()[2:]
  2934  	expectedB0 := byte(wsBinaryMessage) | wsFinalBit | wsRsv1Bit
  2935  	expectedPS := len(compressed)
  2936  	expectedB1 := byte(expectedPS)
  2937  
  2938  	if b := header[0]; b != expectedB0 {
  2939  		t.Fatalf("Expected first byte to be %v, got %v", expectedB0, b)
  2940  	}
  2941  	if b := header[1]; b != expectedB1 {
  2942  		t.Fatalf("Expected second byte to be %v, got %v", expectedB1, b)
  2943  	}
  2944  	if len(body) != expectedPS {
  2945  		t.Fatalf("Expected payload length to be %v, got %v", expectedPS, len(body))
  2946  	}
  2947  	if !bytes.Equal(body, compressed) {
  2948  		t.Fatalf("Unexpected compress body: %q", body)
  2949  	}
  2950  
  2951  	wc.mu.Lock()
  2952  	wc.buf.Reset()
  2953  	wc.mu.Unlock()
  2954  
  2955  	payload = "small"
  2956  	natsPub(t, nc, "foo", []byte(payload))
  2957  	msgProto = fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
  2958  	res = &bytes.Buffer{}
  2959  	for total := 0; total < len(msgProto); {
  2960  		l := testWSReadFrame(t, br)
  2961  		n, _ := res.Write(l)
  2962  		total += n
  2963  	}
  2964  	if !bytes.Equal([]byte(msgProto), res.Bytes()) {
  2965  		t.Fatalf("Unexpected result: %q", res)
  2966  	}
  2967  	wc.mu.RLock()
  2968  	res = wc.buf
  2969  	wc.mu.RUnlock()
  2970  	if !bytes.HasSuffix(res.Bytes(), []byte(msgProto)) {
  2971  		t.Fatalf("Looks like frame was compressed: %q", res.Bytes())
  2972  	}
  2973  }
  2974  
  2975  func TestWSCompressionWithPartialWrite(t *testing.T) {
  2976  	payload := "This is the content of a message that will be compresseddddddddddddddddddddd."
  2977  	msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload)
  2978  
  2979  	o := testWSOptions()
  2980  	o.Websocket.Compression = true
  2981  	s := RunServer(o)
  2982  	defer s.Shutdown()
  2983  
  2984  	c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port)
  2985  	defer c.Close()
  2986  
  2987  	proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n"))
  2988  	c.Write(proto)
  2989  	l := testWSReadFrame(t, br)
  2990  	if !bytes.Equal(l, []byte(pongProto)) {
  2991  		t.Fatalf("Expected PONG, got %q", l)
  2992  	}
  2993  
  2994  	pingPayload := []byte("my ping")
  2995  	pingFromWSClient := testWSCreateClientMsg(wsPingMessage, 1, true, false, pingPayload)
  2996  
  2997  	var wc *testWSWrappedConn
  2998  	var ws *client
  2999  	s.mu.Lock()
  3000  	for _, c := range s.clients {
  3001  		ws = c
  3002  		c.mu.Lock()
  3003  		wc = &testWSWrappedConn{
  3004  			Conn: c.nc,
  3005  			buf:  &bytes.Buffer{},
  3006  		}
  3007  		c.nc = wc
  3008  		c.mu.Unlock()
  3009  		break
  3010  	}
  3011  	s.mu.Unlock()
  3012  
  3013  	wc.mu.Lock()
  3014  	wc.partial = true
  3015  	wc.mu.Unlock()
  3016  
  3017  	nc := natsConnect(t, s.ClientURL())
  3018  	defer nc.Close()
  3019  
  3020  	expected := &bytes.Buffer{}
  3021  	for i := 0; i < 10; i++ {
  3022  		if i > 0 {
  3023  			time.Sleep(10 * time.Millisecond)
  3024  		}
  3025  		expected.Write([]byte(msgProto))
  3026  		natsPub(t, nc, "foo", []byte(payload))
  3027  		if i == 1 {
  3028  			c.Write(pingFromWSClient)
  3029  		}
  3030  	}
  3031  
  3032  	var gotPingResponse bool
  3033  	res := &bytes.Buffer{}
  3034  	for total := 0; total < 10*len(msgProto); {
  3035  		l := testWSReadFrame(t, br)
  3036  		if bytes.Equal(l, pingPayload) {
  3037  			gotPingResponse = true
  3038  		} else {
  3039  			n, _ := res.Write(l)
  3040  			total += n
  3041  		}
  3042  	}
  3043  	if !bytes.Equal(expected.Bytes(), res.Bytes()) {
  3044  		t.Fatalf("Unexpected result: %q", res)
  3045  	}
  3046  	if !gotPingResponse {
  3047  		t.Fatal("Did not get the ping response")
  3048  	}
  3049  
  3050  	checkFor(t, time.Second, 15*time.Millisecond, func() error {
  3051  		ws.mu.Lock()
  3052  		pb := ws.out.pb
  3053  		wf := ws.ws.frames
  3054  		fs := ws.ws.fs
  3055  		ws.mu.Unlock()
  3056  		if pb != 0 || len(wf) != 0 || fs != 0 {
  3057  			return fmt.Errorf("Expected pb, wf and fs to be 0, got %v, %v, %v", pb, wf, fs)
  3058  		}
  3059  		return nil
  3060  	})
  3061  }
  3062  
  3063  func TestWSCompressionFrameSizeLimit(t *testing.T) {
  3064  	for _, test := range []struct {
  3065  		name      string
  3066  		maskWrite bool
  3067  		noLimit   bool
  3068  	}{
  3069  		{"no write masking", false, false},
  3070  		{"write masking", true, false},
  3071  	} {
  3072  		t.Run(test.name, func(t *testing.T) {
  3073  			opts := testWSOptions()
  3074  			opts.MaxPending = MAX_PENDING_SIZE
  3075  			s := &Server{opts: opts}
  3076  			c := &client{srv: s, ws: &websocket{compress: true, browser: true, nocompfrag: test.noLimit, maskwrite: test.maskWrite}}
  3077  			c.initClient()
  3078  
  3079  			uncompressedPayload := make([]byte, 2*wsFrameSizeForBrowsers)
  3080  			for i := 0; i < len(uncompressedPayload); i++ {
  3081  				uncompressedPayload[i] = byte(rand.Intn(256))
  3082  			}
  3083  
  3084  			c.mu.Lock()
  3085  			c.out.nb = append(net.Buffers(nil), uncompressedPayload)
  3086  			nb, _ := c.collapsePtoNB()
  3087  			c.mu.Unlock()
  3088  
  3089  			if test.noLimit && len(nb) != 2 {
  3090  				t.Fatalf("There should be only 2 buffers, the header and payload, got %v", len(nb))
  3091  			}
  3092  
  3093  			bb := &bytes.Buffer{}
  3094  			var key []byte
  3095  			for i, b := range nb {
  3096  				if !test.noLimit {
  3097  					// frame header buffer are always very small. The payload should not be more
  3098  					// than 10 bytes since that is what we passed as the limit.
  3099  					if len(b) > wsFrameSizeForBrowsers {
  3100  						t.Fatalf("Frame size too big: %v (%q)", len(b), b)
  3101  					}
  3102  				}
  3103  				if test.maskWrite {
  3104  					if i%2 == 0 {
  3105  						key = b[len(b)-4:]
  3106  					} else {
  3107  						wsMaskBuf(key, b)
  3108  					}
  3109  				}
  3110  				// Check frame headers for the proper formatting.
  3111  				if i%2 == 0 {
  3112  					// Only the first frame should have the compress bit set.
  3113  					if b[0]&wsRsv1Bit != 0 {
  3114  						if i > 0 {
  3115  							t.Fatalf("Compressed bit should not be in continuation frame")
  3116  						}
  3117  					} else if i == 0 {
  3118  						t.Fatalf("Compressed bit missing")
  3119  					}
  3120  				} else {
  3121  					if test.noLimit {
  3122  						// Since the payload is likely not well compressed, we are expecting
  3123  						// the length to be > wsFrameSizeForBrowsers
  3124  						if len(b) <= wsFrameSizeForBrowsers {
  3125  							t.Fatalf("Expected frame to be bigger, got %v", len(b))
  3126  						}
  3127  					}
  3128  					// Collect the payload
  3129  					bb.Write(b)
  3130  				}
  3131  			}
  3132  			buf := bb.Bytes()
  3133  			buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
  3134  			dbr := bytes.NewBuffer(buf)
  3135  			d := flate.NewReader(dbr)
  3136  			uncompressed, err := io.ReadAll(d)
  3137  			if err != nil {
  3138  				t.Fatalf("Error reading frame: %v", err)
  3139  			}
  3140  			if !bytes.Equal(uncompressed, uncompressedPayload) {
  3141  				t.Fatalf("Unexpected uncomressed data: %q", uncompressed)
  3142  			}
  3143  		})
  3144  	}
  3145  }
  3146  
  3147  func TestWSBasicAuth(t *testing.T) {
  3148  	for _, test := range []struct {
  3149  		name    string
  3150  		opts    func() *Options
  3151  		user    string
  3152  		pass    string
  3153  		err     string
  3154  		cookies []string
  3155  	}{
  3156  		{
  3157  			"top level auth, no override, wrong u/p",
  3158  			func() *Options {
  3159  				o := testWSOptions()
  3160  				o.Username = "normal"
  3161  				o.Password = "client"
  3162  				return o
  3163  			},
  3164  			"websocket", "client", "-ERR 'Authorization Violation'",
  3165  			nil,
  3166  		},
  3167  		{
  3168  			"top level auth, no override, correct u/p",
  3169  			func() *Options {
  3170  				o := testWSOptions()
  3171  				o.Username = "normal"
  3172  				o.Password = "client"
  3173  				return o
  3174  			},
  3175  			"normal", "client", "",
  3176  			nil,
  3177  		},
  3178  		{
  3179  			"no top level auth, ws auth, wrong u/p",
  3180  			func() *Options {
  3181  				o := testWSOptions()
  3182  				o.Websocket.Username = "websocket"
  3183  				o.Websocket.Password = "client"
  3184  				return o
  3185  			},
  3186  			"normal", "client", "-ERR 'Authorization Violation'",
  3187  			nil,
  3188  		},
  3189  		{
  3190  			"no top level auth, ws auth, correct u/p",
  3191  			func() *Options {
  3192  				o := testWSOptions()
  3193  				o.Websocket.Username = "websocket"
  3194  				o.Websocket.Password = "client"
  3195  				return o
  3196  			},
  3197  			"websocket", "client", "",
  3198  			nil,
  3199  		},
  3200  		{
  3201  			"top level auth, ws override, wrong u/p",
  3202  			func() *Options {
  3203  				o := testWSOptions()
  3204  				o.Username = "normal"
  3205  				o.Password = "client"
  3206  				o.Websocket.Username = "websocket"
  3207  				o.Websocket.Password = "client"
  3208  				return o
  3209  			},
  3210  			"normal", "client", "-ERR 'Authorization Violation'",
  3211  			nil,
  3212  		},
  3213  		{
  3214  			"top level auth, ws override, correct u/p",
  3215  			func() *Options {
  3216  				o := testWSOptions()
  3217  				o.Username = "normal"
  3218  				o.Password = "client"
  3219  				o.Websocket.Username = "websocket"
  3220  				o.Websocket.Password = "client"
  3221  				return o
  3222  			},
  3223  			"websocket", "client", "",
  3224  			nil,
  3225  		},
  3226  		{
  3227  			"username/password from cookies",
  3228  			func() *Options {
  3229  				o := testWSOptions()
  3230  				o.Websocket.UsernameCookie = "un"
  3231  				o.Websocket.PasswordCookie = "pw"
  3232  				o.Username = "me"
  3233  				o.Password = "s3cr3t!"
  3234  				return o
  3235  			},
  3236  			"", "", "",
  3237  			[]string{"un=me", "pw=s3cr3t!"},
  3238  		},
  3239  		{
  3240  			"bad username/ good password from cookies",
  3241  			func() *Options {
  3242  				o := testWSOptions()
  3243  				o.Websocket.UsernameCookie = "un"
  3244  				o.Websocket.PasswordCookie = "pw"
  3245  				o.Username = "me"
  3246  				o.Password = "s3cr3t!"
  3247  				return o
  3248  			},
  3249  			"", "", "-ERR 'Authorization Violation",
  3250  			[]string{"un=m", "pw=s3cr3t!"},
  3251  		},
  3252  		{
  3253  			"good username/ bad password from cookies",
  3254  			func() *Options {
  3255  				o := testWSOptions()
  3256  				o.Websocket.UsernameCookie = "un"
  3257  				o.Websocket.PasswordCookie = "pw"
  3258  				o.Username = "me"
  3259  				o.Password = "s3cr3t!"
  3260  				return o
  3261  			},
  3262  			"", "", "-ERR 'Authorization Violation",
  3263  			[]string{"un=me", "pw=hi!"},
  3264  		},
  3265  		{
  3266  			"token from cookie",
  3267  			func() *Options {
  3268  				o := testWSOptions()
  3269  				o.Websocket.TokenCookie = "tok"
  3270  				o.Authorization = "l3tm31n!"
  3271  				return o
  3272  			},
  3273  			"", "", "",
  3274  			[]string{"tok=l3tm31n!"},
  3275  		},
  3276  		{
  3277  			"bad token from cookie",
  3278  			func() *Options {
  3279  				o := testWSOptions()
  3280  				o.Websocket.TokenCookie = "tok"
  3281  				o.Authorization = "l3tm31n!"
  3282  				return o
  3283  			},
  3284  			"", "", "-ERR 'Authorization Violation",
  3285  			[]string{"tok=hello!"},
  3286  		},
  3287  	} {
  3288  		t.Run(test.name, func(t *testing.T) {
  3289  			o := test.opts()
  3290  			s := RunServer(o)
  3291  			defer s.Shutdown()
  3292  
  3293  			wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port, test.cookies...)
  3294  			defer wsc.Close()
  3295  
  3296  			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
  3297  				test.user, test.pass)
  3298  
  3299  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3300  			if _, err := wsc.Write(wsmsg); err != nil {
  3301  				t.Fatalf("Error sending message: %v", err)
  3302  			}
  3303  			msg := testWSReadFrame(t, br)
  3304  			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3305  				t.Fatalf("Expected to receive PONG, got %q", msg)
  3306  			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
  3307  				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
  3308  			}
  3309  		})
  3310  	}
  3311  }
  3312  
  3313  func TestWSAuthTimeout(t *testing.T) {
  3314  	for _, test := range []struct {
  3315  		name string
  3316  		at   float64
  3317  		wat  float64
  3318  		err  string
  3319  	}{
  3320  		{"use top-level auth timeout", 10.0, 0.0, ""},
  3321  		{"use websocket auth timeout", 10.0, 0.05, "-ERR 'Authentication Timeout'"},
  3322  	} {
  3323  		t.Run(test.name, func(t *testing.T) {
  3324  			o := testWSOptions()
  3325  			o.AuthTimeout = test.at
  3326  			o.Websocket.Username = "websocket"
  3327  			o.Websocket.Password = "client"
  3328  			o.Websocket.AuthTimeout = test.wat
  3329  			s := RunServer(o)
  3330  			defer s.Shutdown()
  3331  
  3332  			wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
  3333  			defer wsc.Close()
  3334  
  3335  			var info serverInfo
  3336  			json.Unmarshal([]byte(l[5:]), &info)
  3337  			// Make sure that we are told that auth is required.
  3338  			if !info.AuthRequired {
  3339  				t.Fatalf("Expected auth required, was not: %q", l)
  3340  			}
  3341  			start := time.Now()
  3342  			// Wait before sending connect
  3343  			time.Sleep(100 * time.Millisecond)
  3344  			connectProto := "CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"websocket\",\"pass\":\"client\"}\r\nPING\r\n"
  3345  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3346  			if _, err := wsc.Write(wsmsg); err != nil {
  3347  				t.Fatalf("Error sending message: %v", err)
  3348  			}
  3349  			msg := testWSReadFrame(t, br)
  3350  			if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
  3351  				t.Fatalf("Expected to receive %q error, got %q", test.err, msg)
  3352  			} else if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3353  				t.Fatalf("Unexpected error: %q", msg)
  3354  			}
  3355  			if dur := time.Since(start); dur > time.Second {
  3356  				t.Fatalf("Too long to get timeout error: %v", dur)
  3357  			}
  3358  		})
  3359  	}
  3360  }
  3361  
  3362  func TestWSTokenAuth(t *testing.T) {
  3363  	for _, test := range []struct {
  3364  		name  string
  3365  		opts  func() *Options
  3366  		token string
  3367  		err   string
  3368  	}{
  3369  		{
  3370  			"top level auth, no override, wrong token",
  3371  			func() *Options {
  3372  				o := testWSOptions()
  3373  				o.Authorization = "goodtoken"
  3374  				return o
  3375  			},
  3376  			"badtoken", "-ERR 'Authorization Violation'",
  3377  		},
  3378  		{
  3379  			"top level auth, no override, correct token",
  3380  			func() *Options {
  3381  				o := testWSOptions()
  3382  				o.Authorization = "goodtoken"
  3383  				return o
  3384  			},
  3385  			"goodtoken", "",
  3386  		},
  3387  		{
  3388  			"no top level auth, ws auth, wrong token",
  3389  			func() *Options {
  3390  				o := testWSOptions()
  3391  				o.Websocket.Token = "goodtoken"
  3392  				return o
  3393  			},
  3394  			"badtoken", "-ERR 'Authorization Violation'",
  3395  		},
  3396  		{
  3397  			"no top level auth, ws auth, correct token",
  3398  			func() *Options {
  3399  				o := testWSOptions()
  3400  				o.Websocket.Token = "goodtoken"
  3401  				return o
  3402  			},
  3403  			"goodtoken", "",
  3404  		},
  3405  		{
  3406  			"top level auth, ws override, wrong token",
  3407  			func() *Options {
  3408  				o := testWSOptions()
  3409  				o.Authorization = "clienttoken"
  3410  				o.Websocket.Token = "websockettoken"
  3411  				return o
  3412  			},
  3413  			"clienttoken", "-ERR 'Authorization Violation'",
  3414  		},
  3415  		{
  3416  			"top level auth, ws override, correct token",
  3417  			func() *Options {
  3418  				o := testWSOptions()
  3419  				o.Authorization = "clienttoken"
  3420  				o.Websocket.Token = "websockettoken"
  3421  				return o
  3422  			},
  3423  			"websockettoken", "",
  3424  		},
  3425  	} {
  3426  		t.Run(test.name, func(t *testing.T) {
  3427  			o := test.opts()
  3428  			s := RunServer(o)
  3429  			defer s.Shutdown()
  3430  
  3431  			wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
  3432  			defer wsc.Close()
  3433  
  3434  			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"auth_token\":\"%s\"}\r\nPING\r\n",
  3435  				test.token)
  3436  
  3437  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3438  			if _, err := wsc.Write(wsmsg); err != nil {
  3439  				t.Fatalf("Error sending message: %v", err)
  3440  			}
  3441  			msg := testWSReadFrame(t, br)
  3442  			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3443  				t.Fatalf("Expected to receive PONG, got %q", msg)
  3444  			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
  3445  				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
  3446  			}
  3447  		})
  3448  	}
  3449  }
  3450  
  3451  func TestWSBindToProperAccount(t *testing.T) {
  3452  	conf := createConfFile(t, []byte(fmt.Sprintf(`
  3453  		listen: "127.0.0.1:-1"
  3454  		accounts {
  3455  			a {
  3456  				users [
  3457  					{user: a, password: pwd, allowed_connection_types: ["%s", "%s"]}
  3458  				]
  3459  			}
  3460  			b {
  3461  				users [
  3462  					{user: b, password: pwd}
  3463  				]
  3464  			}
  3465  		}
  3466  		websocket {
  3467  			listen: "127.0.0.1:-1"
  3468  			no_tls: true
  3469  		}
  3470  	`, jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)))) // on purpose use lower case to ensure that it is converted.
  3471  	s, o := RunServerWithConfig(conf)
  3472  	defer s.Shutdown()
  3473  
  3474  	nc := natsConnect(t, fmt.Sprintf("nats://a:pwd@127.0.0.1:%d", o.Port))
  3475  	defer nc.Close()
  3476  
  3477  	sub := natsSubSync(t, nc, "foo")
  3478  
  3479  	wsc, br, _ := testNewWSClient(t, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port, noTLS: true})
  3480  	// Send CONNECT and PING
  3481  	wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false,
  3482  		[]byte(fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", "a", "pwd")))
  3483  	if _, err := wsc.Write(wsmsg); err != nil {
  3484  		t.Fatalf("Error sending message: %v", err)
  3485  	}
  3486  	// Wait for the PONG
  3487  	if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3488  		t.Fatalf("Expected PONG, got %s", msg)
  3489  	}
  3490  
  3491  	wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n"))
  3492  	if _, err := wsc.Write(wsmsg); err != nil {
  3493  		t.Fatalf("Error sending message: %v", err)
  3494  	}
  3495  
  3496  	natsNexMsg(t, sub, time.Second)
  3497  }
  3498  
  3499  func TestWSUsersAuth(t *testing.T) {
  3500  	users := []*User{{Username: "user", Password: "pwd"}}
  3501  	for _, test := range []struct {
  3502  		name string
  3503  		opts func() *Options
  3504  		user string
  3505  		pass string
  3506  		err  string
  3507  	}{
  3508  		{
  3509  			"no filtering, wrong user",
  3510  			func() *Options {
  3511  				o := testWSOptions()
  3512  				o.Users = users
  3513  				return o
  3514  			},
  3515  			"wronguser", "pwd", "-ERR 'Authorization Violation'",
  3516  		},
  3517  		{
  3518  			"no filtering, correct user",
  3519  			func() *Options {
  3520  				o := testWSOptions()
  3521  				o.Users = users
  3522  				return o
  3523  			},
  3524  			"user", "pwd", "",
  3525  		},
  3526  		{
  3527  			"filering, user not allowed",
  3528  			func() *Options {
  3529  				o := testWSOptions()
  3530  				o.Users = users
  3531  				// Only allowed for regular clients
  3532  				o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard})
  3533  				return o
  3534  			},
  3535  			"user", "pwd", "-ERR 'Authorization Violation'",
  3536  		},
  3537  		{
  3538  			"filtering, user allowed",
  3539  			func() *Options {
  3540  				o := testWSOptions()
  3541  				o.Users = users
  3542  				o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket})
  3543  				return o
  3544  			},
  3545  			"user", "pwd", "",
  3546  		},
  3547  		{
  3548  			"filtering, wrong password",
  3549  			func() *Options {
  3550  				o := testWSOptions()
  3551  				o.Users = users
  3552  				o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket})
  3553  				return o
  3554  			},
  3555  			"user", "badpassword", "-ERR 'Authorization Violation'",
  3556  		},
  3557  	} {
  3558  		t.Run(test.name, func(t *testing.T) {
  3559  			o := test.opts()
  3560  			s := RunServer(o)
  3561  			defer s.Shutdown()
  3562  
  3563  			wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
  3564  			defer wsc.Close()
  3565  
  3566  			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n",
  3567  				test.user, test.pass)
  3568  
  3569  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3570  			if _, err := wsc.Write(wsmsg); err != nil {
  3571  				t.Fatalf("Error sending message: %v", err)
  3572  			}
  3573  			msg := testWSReadFrame(t, br)
  3574  			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3575  				t.Fatalf("Expected to receive PONG, got %q", msg)
  3576  			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
  3577  				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
  3578  			}
  3579  		})
  3580  	}
  3581  }
  3582  
  3583  func TestWSNoAuthUserValidation(t *testing.T) {
  3584  	o := testWSOptions()
  3585  	o.Users = []*User{{Username: "user", Password: "pwd"}}
  3586  	// Should fail because it is not part of o.Users.
  3587  	o.Websocket.NoAuthUser = "notfound"
  3588  	if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") {
  3589  		t.Fatalf("Expected error saying not present as user, got %v", err)
  3590  	}
  3591  	// Set a valid no auth user for global options, but still should fail because
  3592  	// of o.Websocket.NoAuthUser
  3593  	o.NoAuthUser = "user"
  3594  	o.Websocket.NoAuthUser = "notfound"
  3595  	if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") {
  3596  		t.Fatalf("Expected error saying not present as user, got %v", err)
  3597  	}
  3598  }
  3599  
  3600  func TestWSNoAuthUser(t *testing.T) {
  3601  	for _, test := range []struct {
  3602  		name         string
  3603  		override     bool
  3604  		useAuth      bool
  3605  		expectedUser string
  3606  		expectedAcc  string
  3607  	}{
  3608  		{"no override, no user provided", false, false, "noauth", "normal"},
  3609  		{"no override, user povided", false, true, "user", "normal"},
  3610  		{"override, no user provided", true, false, "wsnoauth", "websocket"},
  3611  		{"override, user provided", true, true, "wsuser", "websocket"},
  3612  	} {
  3613  		t.Run(test.name, func(t *testing.T) {
  3614  			o := testWSOptions()
  3615  			normalAcc := NewAccount("normal")
  3616  			websocketAcc := NewAccount("websocket")
  3617  			o.Accounts = []*Account{normalAcc, websocketAcc}
  3618  			o.Users = []*User{
  3619  				{Username: "noauth", Password: "pwd", Account: normalAcc},
  3620  				{Username: "user", Password: "pwd", Account: normalAcc},
  3621  				{Username: "wsnoauth", Password: "pwd", Account: websocketAcc},
  3622  				{Username: "wsuser", Password: "pwd", Account: websocketAcc},
  3623  			}
  3624  			o.NoAuthUser = "noauth"
  3625  			if test.override {
  3626  				o.Websocket.NoAuthUser = "wsnoauth"
  3627  			}
  3628  			s := RunServer(o)
  3629  			defer s.Shutdown()
  3630  
  3631  			wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
  3632  			defer wsc.Close()
  3633  
  3634  			var info serverInfo
  3635  			json.Unmarshal([]byte(l[5:]), &info)
  3636  
  3637  			var connectProto string
  3638  			if test.useAuth {
  3639  				connectProto = fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"pwd\"}\r\nPING\r\n",
  3640  					test.expectedUser)
  3641  			} else {
  3642  				connectProto = "CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"
  3643  			}
  3644  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3645  			if _, err := wsc.Write(wsmsg); err != nil {
  3646  				t.Fatalf("Error sending message: %v", err)
  3647  			}
  3648  			msg := testWSReadFrame(t, br)
  3649  			if !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3650  				t.Fatalf("Unexpected error: %q", msg)
  3651  			}
  3652  
  3653  			c := s.getClient(info.CID)
  3654  			c.mu.Lock()
  3655  			uname := c.opts.Username
  3656  			aname := c.acc.GetName()
  3657  			c.mu.Unlock()
  3658  			if uname != test.expectedUser {
  3659  				t.Fatalf("Expected selected user to be %q, got %q", test.expectedUser, uname)
  3660  			}
  3661  			if aname != test.expectedAcc {
  3662  				t.Fatalf("Expected selected account to be %q, got %q", test.expectedAcc, aname)
  3663  			}
  3664  		})
  3665  	}
  3666  }
  3667  
  3668  func TestWSNkeyAuth(t *testing.T) {
  3669  	nkp, _ := nkeys.CreateUser()
  3670  	pub, _ := nkp.PublicKey()
  3671  
  3672  	wsnkp, _ := nkeys.CreateUser()
  3673  	wspub, _ := wsnkp.PublicKey()
  3674  
  3675  	badkp, _ := nkeys.CreateUser()
  3676  	badpub, _ := badkp.PublicKey()
  3677  
  3678  	for _, test := range []struct {
  3679  		name string
  3680  		opts func() *Options
  3681  		nkey string
  3682  		kp   nkeys.KeyPair
  3683  		err  string
  3684  	}{
  3685  		{
  3686  			"no filtering, wrong nkey",
  3687  			func() *Options {
  3688  				o := testWSOptions()
  3689  				o.Nkeys = []*NkeyUser{{Nkey: pub}}
  3690  				return o
  3691  			},
  3692  			badpub, badkp, "-ERR 'Authorization Violation'",
  3693  		},
  3694  		{
  3695  			"no filtering, correct nkey",
  3696  			func() *Options {
  3697  				o := testWSOptions()
  3698  				o.Nkeys = []*NkeyUser{{Nkey: pub}}
  3699  				return o
  3700  			},
  3701  			pub, nkp, "",
  3702  		},
  3703  		{
  3704  			"filtering, nkey not allowed",
  3705  			func() *Options {
  3706  				o := testWSOptions()
  3707  				o.Nkeys = []*NkeyUser{
  3708  					{
  3709  						Nkey:                   pub,
  3710  						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard}),
  3711  					},
  3712  					{
  3713  						Nkey:                   wspub,
  3714  						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeWebsocket}),
  3715  					},
  3716  				}
  3717  				return o
  3718  			},
  3719  			pub, nkp, "-ERR 'Authorization Violation'",
  3720  		},
  3721  		{
  3722  			"filtering, correct nkey",
  3723  			func() *Options {
  3724  				o := testWSOptions()
  3725  				o.Nkeys = []*NkeyUser{
  3726  					{Nkey: pub},
  3727  					{
  3728  						Nkey:                   wspub,
  3729  						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}),
  3730  					},
  3731  				}
  3732  				return o
  3733  			},
  3734  			wspub, wsnkp, "",
  3735  		},
  3736  		{
  3737  			"filtering, wrong nkey",
  3738  			func() *Options {
  3739  				o := testWSOptions()
  3740  				o.Nkeys = []*NkeyUser{
  3741  					{
  3742  						Nkey:                   wspub,
  3743  						AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}),
  3744  					},
  3745  				}
  3746  				return o
  3747  			},
  3748  			badpub, badkp, "-ERR 'Authorization Violation'",
  3749  		},
  3750  	} {
  3751  		t.Run(test.name, func(t *testing.T) {
  3752  			o := test.opts()
  3753  			s := RunServer(o)
  3754  			defer s.Shutdown()
  3755  
  3756  			wsc, br, infoMsg := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port)
  3757  			defer wsc.Close()
  3758  
  3759  			// Sign Nonce
  3760  			var info nonceInfo
  3761  			json.Unmarshal([]byte(infoMsg[5:]), &info)
  3762  			sigraw, _ := test.kp.Sign([]byte(info.Nonce))
  3763  			sig := base64.RawURLEncoding.EncodeToString(sigraw)
  3764  
  3765  			connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"nkey\":\"%s\",\"sig\":\"%s\"}\r\nPING\r\n", test.nkey, sig)
  3766  
  3767  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto))
  3768  			if _, err := wsc.Write(wsmsg); err != nil {
  3769  				t.Fatalf("Error sending message: %v", err)
  3770  			}
  3771  			msg := testWSReadFrame(t, br)
  3772  			if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  3773  				t.Fatalf("Expected to receive PONG, got %q", msg)
  3774  			} else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) {
  3775  				t.Fatalf("Expected to receive %q, got %q", test.err, msg)
  3776  			}
  3777  		})
  3778  	}
  3779  }
  3780  
  3781  func TestWSJWTWithAllowedConnectionTypes(t *testing.T) {
  3782  	o := testWSOptions()
  3783  	setupAddTrusted(o)
  3784  	s := RunServer(o)
  3785  	buildMemAccResolver(s)
  3786  	defer s.Shutdown()
  3787  
  3788  	for _, test := range []struct {
  3789  		name            string
  3790  		connectionTypes []string
  3791  		expectedAnswer  string
  3792  	}{
  3793  		{"not allowed", []string{jwt.ConnectionTypeStandard}, "-ERR"},
  3794  		{"allowed", []string{jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)}, "+OK"},
  3795  		{"allowed with unknown", []string{jwt.ConnectionTypeWebsocket, "SomeNewType"}, "+OK"},
  3796  		{"not allowed with unknown", []string{"SomeNewType"}, "-ERR"},
  3797  	} {
  3798  		t.Run(test.name, func(t *testing.T) {
  3799  			nuc := newJWTTestUserClaims()
  3800  			nuc.AllowedConnectionTypes = test.connectionTypes
  3801  			claimOpt := testClaimsOptions{
  3802  				nuc:          nuc,
  3803  				expectAnswer: test.expectedAnswer,
  3804  			}
  3805  			_, c, _, _ := testWSWithClaims(t, s, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port}, claimOpt)
  3806  			c.Close()
  3807  		})
  3808  	}
  3809  }
  3810  
  3811  func TestWSJWTCookieUser(t *testing.T) {
  3812  	nucSigFunc := func() *jwt.UserClaims { return newJWTTestUserClaims() }
  3813  	nucBearerFunc := func() *jwt.UserClaims {
  3814  		ret := newJWTTestUserClaims()
  3815  		ret.BearerToken = true
  3816  		return ret
  3817  	}
  3818  
  3819  	o := testWSOptions()
  3820  	setupAddTrusted(o)
  3821  	setupAddCookie(o)
  3822  	s := RunServer(o)
  3823  	buildMemAccResolver(s)
  3824  	defer s.Shutdown()
  3825  
  3826  	genJwt := func(t *testing.T, nuc *jwt.UserClaims) string {
  3827  		okp, _ := nkeys.FromSeed(oSeed)
  3828  
  3829  		akp, _ := nkeys.CreateAccount()
  3830  		apub, _ := akp.PublicKey()
  3831  
  3832  		nac := jwt.NewAccountClaims(apub)
  3833  		ajwt, err := nac.Encode(okp)
  3834  		if err != nil {
  3835  			t.Fatalf("Error generating account JWT: %v", err)
  3836  		}
  3837  
  3838  		nkp, _ := nkeys.CreateUser()
  3839  		pub, _ := nkp.PublicKey()
  3840  		nuc.Subject = pub
  3841  		jwt, err := nuc.Encode(akp)
  3842  		if err != nil {
  3843  			t.Fatalf("Error generating user JWT: %v", err)
  3844  		}
  3845  		addAccountToMemResolver(s, apub, ajwt)
  3846  		return jwt
  3847  	}
  3848  
  3849  	cliOpts := testWSClientOptions{
  3850  		host: o.Websocket.Host,
  3851  		port: o.Websocket.Port,
  3852  	}
  3853  	for _, test := range []struct {
  3854  		name         string
  3855  		nuc          *jwt.UserClaims
  3856  		opts         func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions)
  3857  		expectAnswer string
  3858  	}{
  3859  		{
  3860  			name: "protocol auth, non-bearer key, with signature",
  3861  			nuc:  nucSigFunc(),
  3862  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  3863  				return cliOpts, testClaimsOptions{nuc: claims}
  3864  			},
  3865  			expectAnswer: "+OK",
  3866  		},
  3867  		{
  3868  			name: "protocol auth, non-bearer key, w/o required signature",
  3869  			nuc:  nucSigFunc(),
  3870  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  3871  				return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
  3872  			},
  3873  			expectAnswer: "-ERR",
  3874  		},
  3875  		{
  3876  			name: "protocol auth, bearer key, w/o signature",
  3877  			nuc:  nucBearerFunc(),
  3878  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  3879  				return cliOpts, testClaimsOptions{nuc: claims, dontSign: true}
  3880  			},
  3881  			expectAnswer: "+OK",
  3882  		},
  3883  		{
  3884  			name: "cookie auth, non-bearer key, protocol auth fail",
  3885  			nuc:  nucSigFunc(),
  3886  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  3887  				co := cliOpts
  3888  				co.extraHeaders = map[string][]string{}
  3889  				co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)}
  3890  				return co, testClaimsOptions{connectRequest: struct{}{}}
  3891  			},
  3892  			expectAnswer: "-ERR",
  3893  		},
  3894  		{
  3895  			name: "cookie auth, bearer key, protocol auth success with implied cookie jwt",
  3896  			nuc:  nucBearerFunc(),
  3897  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  3898  				co := cliOpts
  3899  				co.extraHeaders = map[string][]string{}
  3900  				co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)}
  3901  				return co, testClaimsOptions{connectRequest: struct{}{}}
  3902  			},
  3903  			expectAnswer: "+OK",
  3904  		},
  3905  		{
  3906  			name: "cookie auth, non-bearer key, protocol auth success via override jwt in CONNECT opts",
  3907  			nuc:  nucSigFunc(),
  3908  			opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) {
  3909  				co := cliOpts
  3910  				co.extraHeaders = map[string][]string{}
  3911  				co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)}
  3912  				return co, testClaimsOptions{nuc: nucBearerFunc()}
  3913  			},
  3914  			expectAnswer: "+OK",
  3915  		},
  3916  	} {
  3917  		t.Run(test.name, func(t *testing.T) {
  3918  			cliOpt, claimOpt := test.opts(t, test.nuc)
  3919  			claimOpt.expectAnswer = test.expectAnswer
  3920  			_, c, _, _ := testWSWithClaims(t, s, cliOpt, claimOpt)
  3921  			c.Close()
  3922  		})
  3923  	}
  3924  	s.Shutdown()
  3925  }
  3926  
  3927  func TestWSReloadTLSConfig(t *testing.T) {
  3928  	template := `
  3929  		listen: "127.0.0.1:-1"
  3930  		websocket {
  3931  			listen: "127.0.0.1:-1"
  3932  			tls {
  3933  				cert_file: '%s'
  3934  				key_file: '%s'
  3935  				ca_file: '../test/configs/certs/ca.pem'
  3936  			}
  3937  		}
  3938  	`
  3939  	conf := createConfFile(t, []byte(fmt.Sprintf(template,
  3940  		"../test/configs/certs/server-noip.pem",
  3941  		"../test/configs/certs/server-key-noip.pem")))
  3942  
  3943  	s, o := RunServerWithConfig(conf)
  3944  	defer s.Shutdown()
  3945  
  3946  	addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port)
  3947  	wsc, err := net.Dial("tcp", addr)
  3948  	if err != nil {
  3949  		t.Fatalf("Error creating ws connection: %v", err)
  3950  	}
  3951  	defer wsc.Close()
  3952  
  3953  	tc := &TLSConfigOpts{CaFile: "../test/configs/certs/ca.pem"}
  3954  	tlsConfig, err := GenTLSConfig(tc)
  3955  	if err != nil {
  3956  		t.Fatalf("Error generating TLS config: %v", err)
  3957  	}
  3958  	tlsConfig.ServerName = "127.0.0.1"
  3959  	tlsConfig.RootCAs = tlsConfig.ClientCAs
  3960  	tlsConfig.ClientCAs = nil
  3961  	wsc = tls.Client(wsc, tlsConfig.Clone())
  3962  	if err := wsc.(*tls.Conn).Handshake(); err == nil || !strings.Contains(err.Error(), "SAN") {
  3963  		t.Fatalf("Unexpected error: %v", err)
  3964  	}
  3965  	wsc.Close()
  3966  
  3967  	reloadUpdateConfig(t, s, conf, fmt.Sprintf(template,
  3968  		"../test/configs/certs/server-cert.pem",
  3969  		"../test/configs/certs/server-key.pem"))
  3970  
  3971  	wsc, err = net.Dial("tcp", addr)
  3972  	if err != nil {
  3973  		t.Fatalf("Error creating ws connection: %v", err)
  3974  	}
  3975  	defer wsc.Close()
  3976  
  3977  	wsc = tls.Client(wsc, tlsConfig.Clone())
  3978  	if err := wsc.(*tls.Conn).Handshake(); err != nil {
  3979  		t.Fatalf("Error on TLS handshake: %v", err)
  3980  	}
  3981  }
  3982  
  3983  type captureClientConnectedLogger struct {
  3984  	DummyLogger
  3985  	ch chan string
  3986  }
  3987  
  3988  func (l *captureClientConnectedLogger) Debugf(format string, v ...interface{}) {
  3989  	msg := fmt.Sprintf(format, v...)
  3990  	if !strings.Contains(msg, "Client connection created") {
  3991  		return
  3992  	}
  3993  	select {
  3994  	case l.ch <- msg:
  3995  	default:
  3996  	}
  3997  }
  3998  
  3999  func TestWSXForwardedFor(t *testing.T) {
  4000  	o := testWSOptions()
  4001  	s := RunServer(o)
  4002  	defer s.Shutdown()
  4003  
  4004  	l := &captureClientConnectedLogger{ch: make(chan string, 1)}
  4005  	s.SetLogger(l, true, false)
  4006  
  4007  	for _, test := range []struct {
  4008  		name          string
  4009  		headers       func() map[string][]string
  4010  		useHdrValue   bool
  4011  		expectedValue string
  4012  	}{
  4013  		{"nil map", func() map[string][]string {
  4014  			return nil
  4015  		}, false, _EMPTY_},
  4016  		{"empty map", func() map[string][]string {
  4017  			return make(map[string][]string)
  4018  		}, false, _EMPTY_},
  4019  		{"header present empty value", func() map[string][]string {
  4020  			m := make(map[string][]string)
  4021  			m[wsXForwardedForHeader] = []string{}
  4022  			return m
  4023  		}, false, _EMPTY_},
  4024  		{"header present invalid IP", func() map[string][]string {
  4025  			m := make(map[string][]string)
  4026  			m[wsXForwardedForHeader] = []string{"not a valid IP"}
  4027  			return m
  4028  		}, false, _EMPTY_},
  4029  		{"header present one IP", func() map[string][]string {
  4030  			m := make(map[string][]string)
  4031  			m[wsXForwardedForHeader] = []string{"1.2.3.4"}
  4032  			return m
  4033  		}, true, "1.2.3.4"},
  4034  		{"header present multiple IPs", func() map[string][]string {
  4035  			m := make(map[string][]string)
  4036  			m[wsXForwardedForHeader] = []string{"1.2.3.4", "5.6.7.8"}
  4037  			return m
  4038  		}, true, "1.2.3.4"},
  4039  		{"header present IPv6", func() map[string][]string {
  4040  			m := make(map[string][]string)
  4041  			m[wsXForwardedForHeader] = []string{"::1"}
  4042  			return m
  4043  		}, true, "[::1]"},
  4044  	} {
  4045  		t.Run(test.name, func(t *testing.T) {
  4046  			c, r, _ := testNewWSClient(t, testWSClientOptions{
  4047  				host:         o.Websocket.Host,
  4048  				port:         o.Websocket.Port,
  4049  				extraHeaders: test.headers(),
  4050  			})
  4051  			defer c.Close()
  4052  			// Send CONNECT and PING
  4053  			wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n"))
  4054  			if _, err := c.Write(wsmsg); err != nil {
  4055  				t.Fatalf("Error sending message: %v", err)
  4056  			}
  4057  			// Wait for the PONG
  4058  			if msg := testWSReadFrame(t, r); !bytes.HasPrefix(msg, []byte("PONG\r\n")) {
  4059  				t.Fatalf("Expected PONG, got %s", msg)
  4060  			}
  4061  			select {
  4062  			case d := <-l.ch:
  4063  				ipAndSlash := fmt.Sprintf("%s/", test.expectedValue)
  4064  				if test.useHdrValue {
  4065  					if !strings.HasPrefix(d, ipAndSlash) {
  4066  						t.Fatalf("Expected debug statement to start with: %q, got %q", ipAndSlash, d)
  4067  					}
  4068  				} else if strings.HasPrefix(d, ipAndSlash) {
  4069  					t.Fatalf("Unexpected debug statement: %q", d)
  4070  				}
  4071  			case <-time.After(time.Second):
  4072  				t.Fatal("Did not get connect debug statement")
  4073  			}
  4074  		})
  4075  	}
  4076  }
  4077  
  4078  type partialWriteConn struct {
  4079  	net.Conn
  4080  }
  4081  
  4082  func (c *partialWriteConn) Write(b []byte) (int, error) {
  4083  	max := len(b)
  4084  	if max > 0 {
  4085  		max = rand.Intn(max)
  4086  		if max == 0 {
  4087  			max = 1
  4088  		}
  4089  	}
  4090  	n, err := c.Conn.Write(b[:max])
  4091  	if err == nil && max != len(b) {
  4092  		err = io.ErrShortWrite
  4093  	}
  4094  	return n, err
  4095  }
  4096  
  4097  func TestWSWithPartialWrite(t *testing.T) {
  4098  	conf := createConfFile(t, []byte(`
  4099  		listen: "127.0.0.1:-1"
  4100  		websocket {
  4101  			listen: "127.0.0.1:-1"
  4102  			no_tls: true
  4103  		}
  4104  	`))
  4105  	s, o := RunServerWithConfig(conf)
  4106  	defer s.Shutdown()
  4107  
  4108  	nc1 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o.Websocket.Port))
  4109  	defer nc1.Close()
  4110  
  4111  	sub := natsSubSync(t, nc1, "foo")
  4112  	sub.SetPendingLimits(-1, -1)
  4113  	natsFlush(t, nc1)
  4114  
  4115  	nc2 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o.Websocket.Port))
  4116  	defer nc2.Close()
  4117  
  4118  	// Replace websocket connections with ones that will produce short writes.
  4119  	s.mu.RLock()
  4120  	for _, c := range s.clients {
  4121  		c.mu.Lock()
  4122  		c.nc = &partialWriteConn{Conn: c.nc}
  4123  		c.mu.Unlock()
  4124  	}
  4125  	s.mu.RUnlock()
  4126  
  4127  	var msgs [][]byte
  4128  	for i := 0; i < 100; i++ {
  4129  		msg := make([]byte, rand.Intn(10000)+10)
  4130  		for j := 0; j < len(msg); j++ {
  4131  			msg[j] = byte('A' + j%26)
  4132  		}
  4133  		msgs = append(msgs, msg)
  4134  		natsPub(t, nc2, "foo", msg)
  4135  	}
  4136  	for i := 0; i < 100; i++ {
  4137  		rmsg := natsNexMsg(t, sub, time.Second)
  4138  		if !bytes.Equal(msgs[i], rmsg.Data) {
  4139  			t.Fatalf("Expected message %q, got %q", msgs[i], rmsg.Data)
  4140  		}
  4141  	}
  4142  }
  4143  
// testWSNoCorruptionWithFrameSizeLimit forms a 3-server cluster, each server
// with a websocket listener. Websocket clients on the second and third servers
// subscribe to "foo", and a websocket client on the first server publishes
// `total` copies of a 100000-byte payload. Before publishing, every websocket
// client connection on all three servers is flagged with c.ws.browser = true.
// Per the comment at that step, this forces a max frame size.
// The test fails if any subscriber receives a message that differs from the
// payload. It also fails if the two subscribers together do not receive
// 2*total messages within 10 seconds after the last publish.
func testWSNoCorruptionWithFrameSizeLimit(t *testing.T, total int) {
	tmpl := `
                listen: "127.0.0.1:-1"
                cluster {
                        name: "local"
                        port: -1
                        %s
                }
                websocket {
                        listen: "127.0.0.1:-1"
                        no_tls: true
                }
        `
	// Seed server: no routes.
	conf1 := createConfFile(t, []byte(fmt.Sprintf(tmpl, _EMPTY_)))
	s1, o1 := RunServerWithConfig(conf1)
	defer s1.Shutdown()

	// The two other servers route to the seed's cluster port.
	routes := fmt.Sprintf("routes: [\"nats://127.0.0.1:%d\"]", o1.Cluster.Port)
	conf2 := createConfFile(t, []byte(fmt.Sprintf(tmpl, routes)))
	s2, o2 := RunServerWithConfig(conf2)
	defer s2.Shutdown()

	conf3 := createConfFile(t, []byte(fmt.Sprintf(tmpl, routes)))
	s3, o3 := RunServerWithConfig(conf3)
	defer s3.Shutdown()

	checkClusterFormed(t, s1, s2, s3)

	nc3 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o3.Websocket.Port))
	defer nc3.Close()

	nc2 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o2.Websocket.Port))
	defer nc2.Close()

	// Deterministic payload (repeating A..Z) so corruption is easy to spot.
	payload := make([]byte, 100000)
	for i := 0; i < len(payload); i++ {
		payload[i] = 'A' + byte(i%26)
	}
	// errCh holds at most one error: later errors are dropped while it is full.
	errCh := make(chan error, 1)
	doneCh := make(chan struct{}, 1)
	// Number of intact messages received by both subscribers together.
	count := int32(0)

	// createSub subscribes nc to "foo" with no pending limits. A mismatching
	// message is reported on errCh, showing the offset of the first differing
	// byte and up to 20 bytes from there. An intact message increments count,
	// and doneCh is signaled when count reaches 2*total.
	createSub := func(nc *nats.Conn) {
		sub := natsSub(t, nc, "foo", func(m *nats.Msg) {
			if !bytes.Equal(m.Data, payload) {
				stop := len(m.Data)
				if l := len(payload); l < stop {
					stop = l
				}
				start := 0
				for i := 0; i < stop; i++ {
					if m.Data[i] != payload[i] {
						start = i
						break
					}
				}
				if stop-start > 20 {
					stop = start + 20
				}
				select {
				case errCh <- fmt.Errorf("Invalid message: [%d bytes same]%s[...]", start, m.Data[start:stop]):
				default:
				}
				return
			}
			if n := atomic.AddInt32(&count, 1); int(n) == 2*total {
				doneCh <- struct{}{}
			}
		})
		sub.SetPendingLimits(-1, -1)
	}
	createSub(nc2)
	createSub(nc3)

	// Wait until the publishing server knows about the remote interest.
	checkSubInterest(t, s1, globalAccountName, "foo", time.Second)

	nc1 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o1.Websocket.Port))
	defer nc1.Close()
	natsFlush(t, nc1)

	// Change websocket connections to force a max frame size.
	for _, s := range []*Server{s1, s2, s3} {
		s.mu.RLock()
		for _, c := range s.clients {
			c.mu.Lock()
			if c.ws != nil {
				c.ws.browser = true
			}
			c.mu.Unlock()
		}
		s.mu.RUnlock()
	}

	// Publish, checking for a reported corruption every 100 messages.
	for i := 0; i < total; i++ {
		natsPub(t, nc1, "foo", payload)
		if i%100 == 0 {
			select {
			case err := <-errCh:
				t.Fatalf("Error: %v", err)
			default:
			}
		}
	}
	select {
	case err := <-errCh:
		t.Fatalf("Error: %v", err)
	case <-doneCh:
		return
	case <-time.After(10 * time.Second):
		t.Fatalf("Test timed out")
	}
}
  4256  
  4257  func TestWSNoCorruptionWithFrameSizeLimit(t *testing.T) {
  4258  	testWSNoCorruptionWithFrameSizeLimit(t, 1000)
  4259  }
  4260  
  4261  // ==================================================================
  4262  // = Benchmark tests
  4263  // ==================================================================
  4264  
  4265  const testWSBenchSubject = "a"
  4266  
  4267  var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()")
  4268  
  4269  func sizedString(sz int) string {
  4270  	b := make([]byte, sz)
  4271  	for i := range b {
  4272  		b[i] = ch[rand.Intn(len(ch))]
  4273  	}
  4274  	return string(b)
  4275  }
  4276  
  4277  func sizedStringForCompression(sz int) string {
  4278  	b := make([]byte, sz)
  4279  	c := byte(0)
  4280  	s := 0
  4281  	for i := range b {
  4282  		if s%20 == 0 {
  4283  			c = ch[rand.Intn(len(ch))]
  4284  		}
  4285  		b[i] = c
  4286  	}
  4287  	return string(b)
  4288  }
  4289  
  4290  func testWSFlushConn(b *testing.B, compress bool, c net.Conn, br *bufio.Reader) {
  4291  	buf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte(pingProto))
  4292  	c.Write(buf)
  4293  	c.SetReadDeadline(time.Now().Add(5 * time.Second))
  4294  	res := testWSReadFrame(b, br)
  4295  	c.SetReadDeadline(time.Time{})
  4296  	if !bytes.HasPrefix(res, []byte(pongProto)) {
  4297  		b.Fatalf("Failed read of PONG: %s\n", res)
  4298  	}
  4299  }
  4300  
  4301  func wsBenchPub(b *testing.B, numPubs int, compress bool, payload string) {
  4302  	b.StopTimer()
  4303  	opts := testWSOptions()
  4304  	opts.Websocket.Compression = compress
  4305  	s := RunServer(opts)
  4306  	defer s.Shutdown()
  4307  
  4308  	extra := 0
  4309  	pubProto := []byte(fmt.Sprintf("PUB %s %d\r\n%s\r\n", testWSBenchSubject, len(payload), payload))
  4310  	singleOpBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, pubProto)
  4311  
  4312  	// Simulate client that would buffer messages before framing/sending.
  4313  	// Figure out how many we can fit in one frame based on b.N and length of pubProto
  4314  	const bufSize = 32768
  4315  	tmpa := [bufSize]byte{}
  4316  	tmp := tmpa[:0]
  4317  	pb := 0
  4318  	for i := 0; i < b.N; i++ {
  4319  		tmp = append(tmp, pubProto...)
  4320  		pb++
  4321  		if len(tmp) >= bufSize {
  4322  			break
  4323  		}
  4324  	}
  4325  	sendBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, tmp)
  4326  	n := b.N / pb
  4327  	extra = b.N - (n * pb)
  4328  
  4329  	wg := sync.WaitGroup{}
  4330  	wg.Add(numPubs)
  4331  
  4332  	type pub struct {
  4333  		c  net.Conn
  4334  		br *bufio.Reader
  4335  		bw *bufio.Writer
  4336  	}
  4337  	var pubs []pub
  4338  	for i := 0; i < numPubs; i++ {
  4339  		wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port)
  4340  		defer wsc.Close()
  4341  		bw := bufio.NewWriterSize(wsc, bufSize)
  4342  		pubs = append(pubs, pub{wsc, br, bw})
  4343  	}
  4344  
  4345  	// Average the amount of bytes sent by iteration
  4346  	avg := len(sendBuf) / pb
  4347  	if extra > 0 {
  4348  		avg += len(singleOpBuf)
  4349  		avg /= 2
  4350  	}
  4351  	b.SetBytes(int64(numPubs * avg))
  4352  	b.StartTimer()
  4353  
  4354  	for i := 0; i < numPubs; i++ {
  4355  		p := pubs[i]
  4356  		go func(p pub) {
  4357  			defer wg.Done()
  4358  			for i := 0; i < n; i++ {
  4359  				p.bw.Write(sendBuf)
  4360  			}
  4361  			for i := 0; i < extra; i++ {
  4362  				p.bw.Write(singleOpBuf)
  4363  			}
  4364  			p.bw.Flush()
  4365  			testWSFlushConn(b, compress, p.c, p.br)
  4366  		}(p)
  4367  	}
  4368  	wg.Wait()
  4369  	b.StopTimer()
  4370  }
  4371  
  4372  func Benchmark_WS_Pubx1_CN_____0b(b *testing.B) {
  4373  	wsBenchPub(b, 1, false, "")
  4374  }
  4375  
  4376  func Benchmark_WS_Pubx1_CY_____0b(b *testing.B) {
  4377  	wsBenchPub(b, 1, true, "")
  4378  }
  4379  
  4380  func Benchmark_WS_Pubx1_CN___128b(b *testing.B) {
  4381  	s := sizedString(128)
  4382  	wsBenchPub(b, 1, false, s)
  4383  }
  4384  
  4385  func Benchmark_WS_Pubx1_CY___128b(b *testing.B) {
  4386  	s := sizedStringForCompression(128)
  4387  	wsBenchPub(b, 1, true, s)
  4388  }
  4389  
  4390  func Benchmark_WS_Pubx1_CN__1024b(b *testing.B) {
  4391  	s := sizedString(1024)
  4392  	wsBenchPub(b, 1, false, s)
  4393  }
  4394  
  4395  func Benchmark_WS_Pubx1_CY__1024b(b *testing.B) {
  4396  	s := sizedStringForCompression(1024)
  4397  	wsBenchPub(b, 1, true, s)
  4398  }
  4399  
  4400  func Benchmark_WS_Pubx1_CN__4096b(b *testing.B) {
  4401  	s := sizedString(4 * 1024)
  4402  	wsBenchPub(b, 1, false, s)
  4403  }
  4404  
  4405  func Benchmark_WS_Pubx1_CY__4096b(b *testing.B) {
  4406  	s := sizedStringForCompression(4 * 1024)
  4407  	wsBenchPub(b, 1, true, s)
  4408  }
  4409  
  4410  func Benchmark_WS_Pubx1_CN__8192b(b *testing.B) {
  4411  	s := sizedString(8 * 1024)
  4412  	wsBenchPub(b, 1, false, s)
  4413  }
  4414  
  4415  func Benchmark_WS_Pubx1_CY__8192b(b *testing.B) {
  4416  	s := sizedStringForCompression(8 * 1024)
  4417  	wsBenchPub(b, 1, true, s)
  4418  }
  4419  
  4420  func Benchmark_WS_Pubx1_CN_32768b(b *testing.B) {
  4421  	s := sizedString(32 * 1024)
  4422  	wsBenchPub(b, 1, false, s)
  4423  }
  4424  
  4425  func Benchmark_WS_Pubx1_CY_32768b(b *testing.B) {
  4426  	s := sizedStringForCompression(32 * 1024)
  4427  	wsBenchPub(b, 1, true, s)
  4428  }
  4429  
  4430  func Benchmark_WS_Pubx5_CN_____0b(b *testing.B) {
  4431  	wsBenchPub(b, 5, false, "")
  4432  }
  4433  
  4434  func Benchmark_WS_Pubx5_CY_____0b(b *testing.B) {
  4435  	wsBenchPub(b, 5, true, "")
  4436  }
  4437  
  4438  func Benchmark_WS_Pubx5_CN___128b(b *testing.B) {
  4439  	s := sizedString(128)
  4440  	wsBenchPub(b, 5, false, s)
  4441  }
  4442  
  4443  func Benchmark_WS_Pubx5_CY___128b(b *testing.B) {
  4444  	s := sizedStringForCompression(128)
  4445  	wsBenchPub(b, 5, true, s)
  4446  }
  4447  
  4448  func Benchmark_WS_Pubx5_CN__1024b(b *testing.B) {
  4449  	s := sizedString(1024)
  4450  	wsBenchPub(b, 5, false, s)
  4451  }
  4452  
  4453  func Benchmark_WS_Pubx5_CY__1024b(b *testing.B) {
  4454  	s := sizedStringForCompression(1024)
  4455  	wsBenchPub(b, 5, true, s)
  4456  }
  4457  
  4458  func Benchmark_WS_Pubx5_CN__4096b(b *testing.B) {
  4459  	s := sizedString(4 * 1024)
  4460  	wsBenchPub(b, 5, false, s)
  4461  }
  4462  
  4463  func Benchmark_WS_Pubx5_CY__4096b(b *testing.B) {
  4464  	s := sizedStringForCompression(4 * 1024)
  4465  	wsBenchPub(b, 5, true, s)
  4466  }
  4467  
  4468  func Benchmark_WS_Pubx5_CN__8192b(b *testing.B) {
  4469  	s := sizedString(8 * 1024)
  4470  	wsBenchPub(b, 5, false, s)
  4471  }
  4472  
  4473  func Benchmark_WS_Pubx5_CY__8192b(b *testing.B) {
  4474  	s := sizedStringForCompression(8 * 1024)
  4475  	wsBenchPub(b, 5, true, s)
  4476  }
  4477  
  4478  func Benchmark_WS_Pubx5_CN_32768b(b *testing.B) {
  4479  	s := sizedString(32 * 1024)
  4480  	wsBenchPub(b, 5, false, s)
  4481  }
  4482  
  4483  func Benchmark_WS_Pubx5_CY_32768b(b *testing.B) {
  4484  	s := sizedStringForCompression(32 * 1024)
  4485  	wsBenchPub(b, 5, true, s)
  4486  }
  4487  
  4488  func wsBenchSub(b *testing.B, numSubs int, compress bool, payload string) {
  4489  	b.StopTimer()
  4490  	opts := testWSOptions()
  4491  	opts.Websocket.Compression = compress
  4492  	s := RunServer(opts)
  4493  	defer s.Shutdown()
  4494  
  4495  	var subs []*bufio.Reader
  4496  	for i := 0; i < numSubs; i++ {
  4497  		wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port)
  4498  		defer wsc.Close()
  4499  		subProto := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress,
  4500  			[]byte(fmt.Sprintf("SUB %s 1\r\nPING\r\n", testWSBenchSubject)))
  4501  		wsc.Write(subProto)
  4502  		// Waiting for PONG
  4503  		testWSReadFrame(b, br)
  4504  		subs = append(subs, br)
  4505  	}
  4506  
  4507  	wg := sync.WaitGroup{}
  4508  	wg.Add(numSubs)
  4509  
  4510  	// Use regular NATS client to publish messages
  4511  	nc := natsConnect(b, s.ClientURL())
  4512  	defer nc.Close()
  4513  
  4514  	b.StartTimer()
  4515  
  4516  	for i := 0; i < numSubs; i++ {
  4517  		br := subs[i]
  4518  		go func(br *bufio.Reader) {
  4519  			defer wg.Done()
  4520  			for count := 0; count < b.N; {
  4521  				msgs := testWSReadFrame(b, br)
  4522  				count += bytes.Count(msgs, []byte("MSG "))
  4523  			}
  4524  		}(br)
  4525  	}
  4526  	for i := 0; i < b.N; i++ {
  4527  		natsPub(b, nc, testWSBenchSubject, []byte(payload))
  4528  	}
  4529  	wg.Wait()
  4530  	b.StopTimer()
  4531  }
  4532  
  4533  func Benchmark_WS_Subx1_CN_____0b(b *testing.B) {
  4534  	wsBenchSub(b, 1, false, "")
  4535  }
  4536  
  4537  func Benchmark_WS_Subx1_CY_____0b(b *testing.B) {
  4538  	wsBenchSub(b, 1, true, "")
  4539  }
  4540  
  4541  func Benchmark_WS_Subx1_CN___128b(b *testing.B) {
  4542  	s := sizedString(128)
  4543  	wsBenchSub(b, 1, false, s)
  4544  }
  4545  
  4546  func Benchmark_WS_Subx1_CY___128b(b *testing.B) {
  4547  	s := sizedStringForCompression(128)
  4548  	wsBenchSub(b, 1, true, s)
  4549  }
  4550  
  4551  func Benchmark_WS_Subx1_CN__1024b(b *testing.B) {
  4552  	s := sizedString(1024)
  4553  	wsBenchSub(b, 1, false, s)
  4554  }
  4555  
  4556  func Benchmark_WS_Subx1_CY__1024b(b *testing.B) {
  4557  	s := sizedStringForCompression(1024)
  4558  	wsBenchSub(b, 1, true, s)
  4559  }
  4560  
  4561  func Benchmark_WS_Subx1_CN__4096b(b *testing.B) {
  4562  	s := sizedString(4096)
  4563  	wsBenchSub(b, 1, false, s)
  4564  }
  4565  
  4566  func Benchmark_WS_Subx1_CY__4096b(b *testing.B) {
  4567  	s := sizedStringForCompression(4096)
  4568  	wsBenchSub(b, 1, true, s)
  4569  }
  4570  
  4571  func Benchmark_WS_Subx1_CN__8192b(b *testing.B) {
  4572  	s := sizedString(8192)
  4573  	wsBenchSub(b, 1, false, s)
  4574  }
  4575  
  4576  func Benchmark_WS_Subx1_CY__8192b(b *testing.B) {
  4577  	s := sizedStringForCompression(8192)
  4578  	wsBenchSub(b, 1, true, s)
  4579  }
  4580  
  4581  func Benchmark_WS_Subx1_CN_32768b(b *testing.B) {
  4582  	s := sizedString(32768)
  4583  	wsBenchSub(b, 1, false, s)
  4584  }
  4585  
  4586  func Benchmark_WS_Subx1_CY_32768b(b *testing.B) {
  4587  	s := sizedStringForCompression(32768)
  4588  	wsBenchSub(b, 1, true, s)
  4589  }
  4590  
  4591  func Benchmark_WS_Subx5_CN_____0b(b *testing.B) {
  4592  	wsBenchSub(b, 5, false, "")
  4593  }
  4594  
  4595  func Benchmark_WS_Subx5_CY_____0b(b *testing.B) {
  4596  	wsBenchSub(b, 5, true, "")
  4597  }
  4598  
  4599  func Benchmark_WS_Subx5_CN___128b(b *testing.B) {
  4600  	s := sizedString(128)
  4601  	wsBenchSub(b, 5, false, s)
  4602  }
  4603  
  4604  func Benchmark_WS_Subx5_CY___128b(b *testing.B) {
  4605  	s := sizedStringForCompression(128)
  4606  	wsBenchSub(b, 5, true, s)
  4607  }
  4608  
  4609  func Benchmark_WS_Subx5_CN__1024b(b *testing.B) {
  4610  	s := sizedString(1024)
  4611  	wsBenchSub(b, 5, false, s)
  4612  }
  4613  
  4614  func Benchmark_WS_Subx5_CY__1024b(b *testing.B) {
  4615  	s := sizedStringForCompression(1024)
  4616  	wsBenchSub(b, 5, true, s)
  4617  }
  4618  
  4619  func Benchmark_WS_Subx5_CN__4096b(b *testing.B) {
  4620  	s := sizedString(4096)
  4621  	wsBenchSub(b, 5, false, s)
  4622  }
  4623  
  4624  func Benchmark_WS_Subx5_CY__4096b(b *testing.B) {
  4625  	s := sizedStringForCompression(4096)
  4626  	wsBenchSub(b, 5, true, s)
  4627  }
  4628  
  4629  func Benchmark_WS_Subx5_CN__8192b(b *testing.B) {
  4630  	s := sizedString(8192)
  4631  	wsBenchSub(b, 5, false, s)
  4632  }
  4633  
  4634  func Benchmark_WS_Subx5_CY__8192b(b *testing.B) {
  4635  	s := sizedStringForCompression(8192)
  4636  	wsBenchSub(b, 5, true, s)
  4637  }
  4638  
  4639  func Benchmark_WS_Subx5_CN_32768b(b *testing.B) {
  4640  	s := sizedString(32768)
  4641  	wsBenchSub(b, 5, false, s)
  4642  }
  4643  
  4644  func Benchmark_WS_Subx5_CY_32768b(b *testing.B) {
  4645  	s := sizedStringForCompression(32768)
  4646  	wsBenchSub(b, 5, true, s)
  4647  }