get.pme.sh/pnats@v0.0.0-20240304004023-26bb5a137ed0/server/websocket_test.go (about) 1 // Copyright 2020-2024 The NATS Authors 2 // Licensed under the Apache License, Version 2.0 (the "License"); 3 // you may not use this file except in compliance with the License. 4 // You may obtain a copy of the License at 5 // 6 // http://www.apache.org/licenses/LICENSE-2.0 7 // 8 // Unless required by applicable law or agreed to in writing, software 9 // distributed under the License is distributed on an "AS IS" BASIS, 10 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 11 // See the License for the specific language governing permissions and 12 // limitations under the License. 13 14 package server 15 16 import ( 17 "bufio" 18 "bytes" 19 "crypto/tls" 20 "encoding/base64" 21 "encoding/binary" 22 "encoding/json" 23 "errors" 24 "fmt" 25 "io" 26 "math/rand" 27 "net" 28 "net/http" 29 "net/url" 30 "reflect" 31 "strings" 32 "sync" 33 "sync/atomic" 34 "testing" 35 "time" 36 37 "github.com/nats-io/jwt/v2" 38 "github.com/nats-io/nats.go" 39 "github.com/nats-io/nkeys" 40 41 "github.com/klauspost/compress/flate" 42 ) 43 44 type testReader struct { 45 buf []byte 46 pos int 47 max int 48 err error 49 } 50 51 func (tr *testReader) Read(p []byte) (int, error) { 52 if tr.err != nil { 53 return 0, tr.err 54 } 55 n := len(tr.buf) - tr.pos 56 if n == 0 { 57 return 0, nil 58 } 59 if n > len(p) { 60 n = len(p) 61 } 62 if tr.max > 0 && n > tr.max { 63 n = tr.max 64 } 65 copy(p, tr.buf[tr.pos:tr.pos+n]) 66 tr.pos += n 67 return n, nil 68 } 69 70 func TestWSGet(t *testing.T) { 71 rb := []byte("012345") 72 73 tr := &testReader{buf: []byte("6789")} 74 75 for _, test := range []struct { 76 name string 77 pos int 78 needed int 79 newpos int 80 trmax int 81 result string 82 reterr bool 83 }{ 84 {"fromrb1", 0, 3, 3, 4, "012", false}, // Partial from read buffer 85 {"fromrb2", 3, 2, 5, 4, "34", false}, // Partial from read buffer 86 {"fromrb3", 5, 1, 6, 4, "5", false}, // Partial from read buffer 87 {"fromtr1", 4, 4, 6, 4, "4567", false}, // Partial from read buffer + some of ioReader 88 {"fromtr2", 4, 6, 6, 4, "456789", false}, // Partial from read buffer + all of ioReader 89 {"fromtr3", 4, 6, 6, 2, "456789", false}, // Partial from read buffer + all of ioReader with several reads 90 {"fromtr4", 4, 6, 6, 2, "", true}, // ioReader returns error 91 } { 92 t.Run(test.name, func(t *testing.T) { 93 tr.pos = 0 94 tr.max = test.trmax 95 if test.reterr { 96 tr.err = fmt.Errorf("on purpose") 97 } 98 res, np, err := wsGet(tr, rb, test.pos, test.needed) 99 if test.reterr { 100 if err == nil { 101 t.Fatalf("Expected error, got none") 102 } 103 if err.Error() != "on purpose" { 104 t.Fatalf("Unexpected error: %v", err) 105 } 106 if np != 0 || res != nil { 107 t.Fatalf("Unexpected returned values: res=%v n=%v", res, np) 108 } 109 return 110 } 111 if err != nil { 112 t.Fatalf("Error on get: %v", err) 113 } 114 if np != test.newpos { 115 t.Fatalf("Expected pos=%v, got %v", test.newpos, np) 116 } 117 if string(res) != test.result { 118 t.Fatalf("Invalid returned content: %s", res) 119 } 120 }) 121 } 122 } 123 124 func TestWSIsControlFrame(t *testing.T) { 125 for _, test := range []struct { 126 name string 127 code wsOpCode 128 isControl bool 129 }{ 130 {"binary", wsBinaryMessage, false}, 131 {"text", wsTextMessage, false}, 132 {"ping", wsPingMessage, true}, 133 {"pong", wsPongMessage, true}, 134 {"close", wsCloseMessage, true}, 135 } { 136 t.Run(test.name, func(t *testing.T) { 137 if res := wsIsControlFrame(test.code); res != test.isControl { 138 t.Fatalf("Expected %q isControl to be %v, got %v", test.name, test.isControl, res) 139 } 140 }) 141 } 142 } 143 144 func testWSSimpleMask(key, buf []byte) { 145 for i := 0; i < len(buf); i++ { 146 buf[i] ^= key[i&3] 147 } 148 } 149 150 func TestWSUnmask(t *testing.T) { 151 key := []byte{1, 2, 3, 4} 152 orgBuf := []byte("this is a clear text") 153 154 mask := func() []byte { 155 t.Helper() 156 buf := append([]byte(nil), orgBuf...) 157 testWSSimpleMask(key, buf) 158 // First ensure that the content is masked. 159 if bytes.Equal(buf, orgBuf) { 160 t.Fatalf("Masking did not do anything: %q", buf) 161 } 162 return buf 163 } 164 165 ri := &wsReadInfo{mask: true} 166 ri.init() 167 copy(ri.mkey[:], key) 168 169 buf := mask() 170 // Unmask in one call 171 ri.unmask(buf) 172 if !bytes.Equal(buf, orgBuf) { 173 t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf) 174 } 175 176 // Unmask in multiple calls 177 buf = mask() 178 ri.mkpos = 0 179 ri.unmask(buf[:3]) 180 ri.unmask(buf[3:11]) 181 ri.unmask(buf[11:]) 182 if !bytes.Equal(buf, orgBuf) { 183 t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf) 184 } 185 } 186 187 func TestWSCreateCloseMessage(t *testing.T) { 188 for _, test := range []struct { 189 name string 190 status int 191 psize int 192 truncated bool 193 }{ 194 {"fits", wsCloseStatusInternalSrvError, 10, false}, 195 {"truncated", wsCloseStatusProtocolError, wsMaxControlPayloadSize + 10, true}, 196 } { 197 t.Run(test.name, func(t *testing.T) { 198 payload := make([]byte, test.psize) 199 for i := 0; i < len(payload); i++ { 200 payload[i] = byte('A' + (i % 26)) 201 } 202 res := wsCreateCloseMessage(test.status, string(payload)) 203 if status := binary.BigEndian.Uint16(res[:2]); int(status) != test.status { 204 t.Fatalf("Expected status to be %v, got %v", test.status, status) 205 } 206 psize := len(res) - 2 207 if !test.truncated { 208 if int(psize) != test.psize { 209 t.Fatalf("Expected size to be %v, got %v", test.psize, psize) 210 } 211 if !bytes.Equal(res[2:], payload) { 212 t.Fatalf("Unexpected result: %q", res[2:]) 213 } 214 return 215 } 216 // Since the payload of a close message contains a 2 byte status, the 217 // actual max text size will be wsMaxControlPayloadSize-2 218 if int(psize) != wsMaxControlPayloadSize-2 { 219 t.Fatalf("Expected size to be capped to %v, got %v", wsMaxControlPayloadSize-2, psize) 220 } 221 if string(res[len(res)-3:]) != "..." { 222 t.Fatalf("Expected res to have `...` at the end, got %q", res[4:]) 223 } 224 }) 225 } 226 } 227 228 func TestWSCreateFrameHeader(t *testing.T) { 229 for _, test := range []struct { 230 name string 231 frameType wsOpCode 232 compressed bool 233 len int 234 }{ 235 {"uncompressed 10", wsBinaryMessage, false, 10}, 236 {"uncompressed 600", wsTextMessage, false, 600}, 237 {"uncompressed 100000", wsTextMessage, false, 100000}, 238 {"compressed 10", wsBinaryMessage, true, 10}, 239 {"compressed 600", wsBinaryMessage, true, 600}, 240 {"compressed 100000", wsTextMessage, true, 100000}, 241 } { 242 t.Run(test.name, func(t *testing.T) { 243 res, _ := wsCreateFrameHeader(false, test.compressed, test.frameType, test.len) 244 // The server is always sending the message has a single frame, 245 // so the "final" bit should be set. 246 expected := byte(test.frameType) | wsFinalBit 247 if test.compressed { 248 expected |= wsRsv1Bit 249 } 250 if b := res[0]; b != expected { 251 t.Fatalf("Expected first byte to be %v, got %v", expected, b) 252 } 253 switch { 254 case test.len <= 125: 255 if len(res) != 2 { 256 t.Fatalf("Frame len should be 2, got %v", len(res)) 257 } 258 if res[1] != byte(test.len) { 259 t.Fatalf("Expected len to be in second byte and be %v, got %v", test.len, res[1]) 260 } 261 case test.len < 65536: 262 // 1+1+2 263 if len(res) != 4 { 264 t.Fatalf("Frame len should be 4, got %v", len(res)) 265 } 266 if res[1] != 126 { 267 t.Fatalf("Second byte value should be 126, got %v", res[1]) 268 } 269 if rl := binary.BigEndian.Uint16(res[2:]); int(rl) != test.len { 270 t.Fatalf("Expected len to be %v, got %v", test.len, rl) 271 } 272 default: 273 // 1+1+8 274 if len(res) != 10 { 275 t.Fatalf("Frame len should be 10, got %v", len(res)) 276 } 277 if res[1] != 127 { 278 t.Fatalf("Second byte value should be 127, got %v", res[1]) 279 } 280 if rl := binary.BigEndian.Uint64(res[2:]); int(rl) != test.len { 281 t.Fatalf("Expected len to be %v, got %v", test.len, rl) 282 } 283 } 284 }) 285 } 286 } 287 288 func testWSCreateClientMsg(frameType wsOpCode, frameNum int, final, compressed bool, payload []byte) []byte { 289 if compressed { 290 buf := &bytes.Buffer{} 291 compressor, _ := flate.NewWriter(buf, 1) 292 compressor.Write(payload) 293 compressor.Flush() 294 payload = buf.Bytes() 295 // The last 4 bytes are dropped 296 payload = payload[:len(payload)-4] 297 } 298 frame := make([]byte, 14+len(payload)) 299 if frameNum == 1 { 300 frame[0] = byte(frameType) 301 } 302 if final { 303 frame[0] |= wsFinalBit 304 } 305 if compressed { 306 frame[0] |= wsRsv1Bit 307 } 308 pos := 1 309 lenPayload := len(payload) 310 switch { 311 case lenPayload <= 125: 312 frame[pos] = byte(lenPayload) | wsMaskBit 313 pos++ 314 case lenPayload < 65536: 315 frame[pos] = 126 | wsMaskBit 316 binary.BigEndian.PutUint16(frame[2:], uint16(lenPayload)) 317 pos += 3 318 default: 319 frame[1] = 127 | wsMaskBit 320 binary.BigEndian.PutUint64(frame[2:], uint64(lenPayload)) 321 pos += 9 322 } 323 key := []byte{1, 2, 3, 4} 324 copy(frame[pos:], key) 325 pos += 4 326 copy(frame[pos:], payload) 327 testWSSimpleMask(key, frame[pos:]) 328 pos += lenPayload 329 return frame[:pos] 330 } 331 332 func testWSSetupForRead() (*client, *wsReadInfo, *testReader) { 333 ri := &wsReadInfo{mask: true} 334 ri.init() 335 tr := &testReader{} 336 opts := DefaultOptions() 337 opts.MaxPending = MAX_PENDING_SIZE 338 s := &Server{opts: opts} 339 c := &client{srv: s, ws: &websocket{}} 340 c.initClient() 341 return c, ri, tr 342 } 343 344 func TestWSReadUncompressedFrames(t *testing.T) { 345 c, ri, tr := testWSSetupForRead() 346 // Create 2 WS messages 347 pl1 := []byte("first message") 348 wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl1) 349 pl2 := []byte("second message") 350 wsmsg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl2) 351 // Add both in single buffer 352 orgrb := append([]byte(nil), wsmsg1...) 353 orgrb = append(orgrb, wsmsg2...) 354 355 rb := append([]byte(nil), orgrb...) 356 bufs, err := c.wsRead(ri, tr, rb) 357 if err != nil { 358 t.Fatalf("Unexpected error: %v", err) 359 } 360 if n := len(bufs); n != 2 { 361 t.Fatalf("Expected 2 buffers, got %v", n) 362 } 363 if !bytes.Equal(bufs[0], pl1) { 364 t.Fatalf("Unexpected content for buffer 1: %s", bufs[0]) 365 } 366 if !bytes.Equal(bufs[1], pl2) { 367 t.Fatalf("Unexpected content for buffer 2: %s", bufs[1]) 368 } 369 370 // Now reset and try with the read buffer not containing full ws frame 371 c, ri, tr = testWSSetupForRead() 372 rb = append([]byte(nil), orgrb...) 373 // Frame is 1+1+4+'first message'. So say we pass with rb of 11 bytes, 374 // then we should get "first" 375 bufs, err = c.wsRead(ri, tr, rb[:11]) 376 if err != nil { 377 t.Fatalf("Unexpected error: %v", err) 378 } 379 if n := len(bufs); n != 1 { 380 t.Fatalf("Unexpected buffer returned: %v", n) 381 } 382 if string(bufs[0]) != "first" { 383 t.Fatalf("Unexpected content: %q", bufs[0]) 384 } 385 // Call again with more data.. 386 bufs, err = c.wsRead(ri, tr, rb[11:32]) 387 if err != nil { 388 t.Fatalf("Unexpected error: %v", err) 389 } 390 if n := len(bufs); n != 2 { 391 t.Fatalf("Unexpected buffer returned: %v", n) 392 } 393 if string(bufs[0]) != " message" { 394 t.Fatalf("Unexpected content: %q", bufs[0]) 395 } 396 if string(bufs[1]) != "second " { 397 t.Fatalf("Unexpected content: %q", bufs[1]) 398 } 399 // Call with the rest 400 bufs, err = c.wsRead(ri, tr, rb[32:]) 401 if err != nil { 402 t.Fatalf("Unexpected error: %v", err) 403 } 404 if n := len(bufs); n != 1 { 405 t.Fatalf("Unexpected buffer returned: %v", n) 406 } 407 if string(bufs[0]) != "message" { 408 t.Fatalf("Unexpected content: %q", bufs[0]) 409 } 410 } 411 412 func TestWSReadCompressedFrames(t *testing.T) { 413 c, ri, tr := testWSSetupForRead() 414 uncompressed := []byte("this is the uncompress data") 415 wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed) 416 rb := append([]byte(nil), wsmsg1...) 417 // Call with some but not all of the payload 418 bufs, err := c.wsRead(ri, tr, rb[:10]) 419 if err != nil { 420 t.Fatalf("Unexpected error: %v", err) 421 } 422 if n := len(bufs); n != 0 { 423 t.Fatalf("Unexpected buffer returned: %v", n) 424 } 425 // Call with the rest, only then should we get the uncompressed data. 426 bufs, err = c.wsRead(ri, tr, rb[10:]) 427 if err != nil { 428 t.Fatalf("Unexpected error: %v", err) 429 } 430 if n := len(bufs); n != 1 { 431 t.Fatalf("Unexpected buffer returned: %v", n) 432 } 433 if !bytes.Equal(bufs[0], uncompressed) { 434 t.Fatalf("Unexpected content: %s", bufs[0]) 435 } 436 // Stress the fact that we use a pool and want to make sure 437 // that if we get a decompressor from the pool, it is properly reset 438 // with the buffer to decompress. 439 // Since we unmask the read buffer, reset it now and fill it 440 // with 10 compressed frames. 441 rb = nil 442 for i := 0; i < 10; i++ { 443 rb = append(rb, wsmsg1...) 444 } 445 bufs, err = c.wsRead(ri, tr, rb) 446 if err != nil { 447 t.Fatalf("Unexpected error: %v", err) 448 } 449 if n := len(bufs); n != 10 { 450 t.Fatalf("Unexpected buffer returned: %v", n) 451 } 452 453 // Compress a message and send it in several frames. 454 buf := &bytes.Buffer{} 455 compressor, _ := flate.NewWriter(buf, 1) 456 compressor.Write(uncompressed) 457 compressor.Flush() 458 compressed := buf.Bytes() 459 // The last 4 bytes are dropped 460 compressed = compressed[:len(compressed)-4] 461 ncomp := 10 462 frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, compressed[:ncomp]) 463 frag1[0] |= wsRsv1Bit 464 frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, compressed[ncomp:]) 465 rb = append([]byte(nil), frag1...) 466 rb = append(rb, frag2...) 467 bufs, err = c.wsRead(ri, tr, rb) 468 if err != nil { 469 t.Fatalf("Unexpected error: %v", err) 470 } 471 if n := len(bufs); n != 1 { 472 t.Fatalf("Unexpected buffer returned: %v", n) 473 } 474 if !bytes.Equal(bufs[0], uncompressed) { 475 t.Fatalf("Unexpected content: %s", bufs[0]) 476 } 477 } 478 479 func TestWSReadCompressedFrameCorrupted(t *testing.T) { 480 c, ri, tr := testWSSetupForRead() 481 uncompressed := []byte("this is the uncompress data") 482 wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed) 483 copy(wsmsg1[10:], []byte{1, 2, 3, 4}) 484 rb := append([]byte(nil), wsmsg1...) 485 bufs, err := c.wsRead(ri, tr, rb) 486 if err == nil || !strings.Contains(err.Error(), "corrupt") { 487 t.Fatalf("Expected error about corrupted data, got %v", err) 488 } 489 if n := len(bufs); n != 0 { 490 t.Fatalf("Expected no buffer, got %v", n) 491 } 492 } 493 494 func TestWSReadVariousFrameSizes(t *testing.T) { 495 for _, test := range []struct { 496 name string 497 size int 498 }{ 499 {"tiny", 100}, 500 {"medium", 1000}, 501 {"large", 70000}, 502 } { 503 t.Run(test.name, func(t *testing.T) { 504 c, ri, tr := testWSSetupForRead() 505 uncompressed := make([]byte, test.size) 506 for i := 0; i < len(uncompressed); i++ { 507 uncompressed[i] = 'A' + byte(i%26) 508 } 509 wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, uncompressed) 510 rb := append([]byte(nil), wsmsg1...) 511 bufs, err := c.wsRead(ri, tr, rb) 512 if err != nil { 513 t.Fatalf("Unexpected error: %v", err) 514 } 515 if n := len(bufs); n != 1 { 516 t.Fatalf("Unexpected buffer returned: %v", n) 517 } 518 if !bytes.Equal(bufs[0], uncompressed) { 519 t.Fatalf("Unexpected content: %s", bufs[0]) 520 } 521 }) 522 } 523 } 524 525 func TestWSReadFragmentedFrames(t *testing.T) { 526 c, ri, tr := testWSSetupForRead() 527 payloads := []string{"first", "second", "third"} 528 var rb []byte 529 for i := 0; i < len(payloads); i++ { 530 final := i == len(payloads)-1 531 frag := testWSCreateClientMsg(wsBinaryMessage, i+1, final, false, []byte(payloads[i])) 532 rb = append(rb, frag...) 533 } 534 bufs, err := c.wsRead(ri, tr, rb) 535 if err != nil { 536 t.Fatalf("Unexpected error: %v", err) 537 } 538 if n := len(bufs); n != 3 { 539 t.Fatalf("Unexpected buffer returned: %v", n) 540 } 541 for i, expected := range payloads { 542 if string(bufs[i]) != expected { 543 t.Fatalf("Unexpected content for buf=%v: %s", i, bufs[i]) 544 } 545 } 546 } 547 548 func TestWSReadPartialFrameHeaderAtEndOfReadBuffer(t *testing.T) { 549 c, ri, tr := testWSSetupForRead() 550 msg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg1")) 551 msg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg2")) 552 rb := append([]byte(nil), msg1...) 553 rb = append(rb, msg2...) 554 // We will pass the first frame + the first byte of the next frame. 555 rbl := rb[:len(msg1)+1] 556 // Make the io reader return the rest of the frame 557 tr.buf = rb[len(msg1)+1:] 558 bufs, err := c.wsRead(ri, tr, rbl) 559 if err != nil { 560 t.Fatalf("Unexpected error: %v", err) 561 } 562 if n := len(bufs); n != 1 { 563 t.Fatalf("Unexpected buffer returned: %v", n) 564 } 565 // We should not have asked to the io reader more than what is needed for reading 566 // the frame header. Since we had already the first byte in the read buffer, 567 // tr.pos should be 1(size)+4(key)=5 568 if tr.pos != 5 { 569 t.Fatalf("Expected reader pos to be 5, got %v", tr.pos) 570 } 571 } 572 573 func TestWSReadPingFrame(t *testing.T) { 574 for _, test := range []struct { 575 name string 576 payload []byte 577 }{ 578 {"without payload", nil}, 579 {"with payload", []byte("optional payload")}, 580 } { 581 t.Run(test.name, func(t *testing.T) { 582 c, ri, tr := testWSSetupForRead() 583 ping := testWSCreateClientMsg(wsPingMessage, 1, true, false, test.payload) 584 rb := append([]byte(nil), ping...) 585 bufs, err := c.wsRead(ri, tr, rb) 586 if err != nil { 587 t.Fatalf("Unexpected error: %v", err) 588 } 589 if n := len(bufs); n != 0 { 590 t.Fatalf("Unexpected buffer returned: %v", n) 591 } 592 // A PONG should have been queued with the payload of the ping 593 c.mu.Lock() 594 nb, _ := c.collapsePtoNB() 595 c.mu.Unlock() 596 if n := len(nb); n == 0 { 597 t.Fatalf("Expected buffers, got %v", n) 598 } 599 if expected := 2 + len(test.payload); expected != len(nb[0]) { 600 t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) 601 } 602 b := nb[0][0] 603 if b&wsFinalBit == 0 { 604 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 605 } 606 if b&byte(wsPongMessage) == 0 { 607 t.Fatalf("Should have been a PONG, it wasn't: %v", b) 608 } 609 if len(test.payload) > 0 { 610 if !bytes.Equal(nb[0][2:], test.payload) { 611 t.Fatalf("Unexpected content: %s", nb[0][2:]) 612 } 613 } 614 }) 615 } 616 } 617 618 func TestWSReadPongFrame(t *testing.T) { 619 for _, test := range []struct { 620 name string 621 payload []byte 622 }{ 623 {"without payload", nil}, 624 {"with payload", []byte("optional payload")}, 625 } { 626 t.Run(test.name, func(t *testing.T) { 627 c, ri, tr := testWSSetupForRead() 628 pong := testWSCreateClientMsg(wsPongMessage, 1, true, false, test.payload) 629 rb := append([]byte(nil), pong...) 630 bufs, err := c.wsRead(ri, tr, rb) 631 if err != nil { 632 t.Fatalf("Unexpected error: %v", err) 633 } 634 if n := len(bufs); n != 0 { 635 t.Fatalf("Unexpected buffer returned: %v", n) 636 } 637 // Nothing should be sent... 638 c.mu.Lock() 639 nb, _ := c.collapsePtoNB() 640 c.mu.Unlock() 641 if n := len(nb); n != 0 { 642 t.Fatalf("Expected no buffer, got %v", n) 643 } 644 }) 645 } 646 } 647 648 func TestWSReadCloseFrame(t *testing.T) { 649 for _, test := range []struct { 650 name string 651 payload []byte 652 }{ 653 {"without payload", nil}, 654 {"with payload", []byte("optional payload")}, 655 } { 656 t.Run(test.name, func(t *testing.T) { 657 c, ri, tr := testWSSetupForRead() 658 // a close message has a status in 2 bytes + optional payload 659 payload := make([]byte, 2+len(test.payload)) 660 binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure) 661 if len(test.payload) > 0 { 662 copy(payload[2:], test.payload) 663 } 664 close := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload) 665 // Have a normal frame prior to close to make sure that wsRead returns 666 // the normal frame along with io.EOF to indicate that wsCloseMessage was received. 667 msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg")) 668 rb := append([]byte(nil), msg...) 669 rb = append(rb, close...) 670 bufs, err := c.wsRead(ri, tr, rb) 671 // It is expected that wsRead returns io.EOF on processing a close. 672 if err != io.EOF { 673 t.Fatalf("Unexpected error: %v", err) 674 } 675 if n := len(bufs); n != 1 { 676 t.Fatalf("Unexpected buffer returned: %v", n) 677 } 678 if string(bufs[0]) != "msg" { 679 t.Fatalf("Unexpected content: %s", bufs[0]) 680 } 681 // A CLOSE should have been queued with the payload of the original close message. 682 c.mu.Lock() 683 nb, _ := c.collapsePtoNB() 684 c.mu.Unlock() 685 if n := len(nb); n == 0 { 686 t.Fatalf("Expected buffers, got %v", n) 687 } 688 if expected := 2 + 2 + len(test.payload); expected != len(nb[0]) { 689 t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) 690 } 691 b := nb[0][0] 692 if b&wsFinalBit == 0 { 693 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 694 } 695 if b&byte(wsCloseMessage) == 0 { 696 t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) 697 } 698 if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure { 699 t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status) 700 } 701 if len(test.payload) > 0 { 702 if !bytes.Equal(nb[0][4:], test.payload) { 703 t.Fatalf("Unexpected content: %s", nb[0][4:]) 704 } 705 } 706 }) 707 } 708 } 709 710 func TestWSReadControlFrameBetweebFragmentedFrames(t *testing.T) { 711 c, ri, tr := testWSSetupForRead() 712 frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("first")) 713 frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, []byte("second")) 714 ctrl := testWSCreateClientMsg(wsPongMessage, 1, true, false, nil) 715 rb := append([]byte(nil), frag1...) 716 rb = append(rb, ctrl...) 717 rb = append(rb, frag2...) 718 bufs, err := c.wsRead(ri, tr, rb) 719 if err != nil { 720 t.Fatalf("Unexpected error: %v", err) 721 } 722 if n := len(bufs); n != 2 { 723 t.Fatalf("Unexpected buffer returned: %v", n) 724 } 725 if string(bufs[0]) != "first" { 726 t.Fatalf("Unexpected content: %s", bufs[0]) 727 } 728 if string(bufs[1]) != "second" { 729 t.Fatalf("Unexpected content: %s", bufs[1]) 730 } 731 } 732 733 func TestWSCloseFrameWithPartialOrInvalid(t *testing.T) { 734 c, ri, tr := testWSSetupForRead() 735 // a close message has a status in 2 bytes + optional payload 736 payloadTxt := []byte("hello") 737 payload := make([]byte, 2+len(payloadTxt)) 738 binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure) 739 copy(payload[2:], payloadTxt) 740 closeMsg := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload) 741 742 // We will pass to wsRead a buffer of small capacity that contains 743 // only 1 byte. 744 closeFirtByte := []byte{closeMsg[0]} 745 // Make the io reader return the rest of the frame 746 tr.buf = closeMsg[1:] 747 bufs, err := c.wsRead(ri, tr, closeFirtByte[:]) 748 // It is expected that wsRead returns io.EOF on processing a close. 749 if err != io.EOF { 750 t.Fatalf("Unexpected error: %v", err) 751 } 752 if n := len(bufs); n != 0 { 753 t.Fatalf("Unexpected buffer returned: %v", n) 754 } 755 // A CLOSE should have been queued with the payload of the original close message. 756 c.mu.Lock() 757 nb, _ := c.collapsePtoNB() 758 c.mu.Unlock() 759 if n := len(nb); n == 0 { 760 t.Fatalf("Expected buffers, got %v", n) 761 } 762 if expected := 2 + 2 + len(payloadTxt); expected != len(nb[0]) { 763 t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) 764 } 765 b := nb[0][0] 766 if b&wsFinalBit == 0 { 767 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 768 } 769 if b&byte(wsCloseMessage) == 0 { 770 t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) 771 } 772 if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure { 773 t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status) 774 } 775 if !bytes.Equal(nb[0][4:], payloadTxt) { 776 t.Fatalf("Unexpected content: %s", nb[0][4:]) 777 } 778 779 // Now test close with invalid status size (1 instead of 2 bytes) 780 c, ri, tr = testWSSetupForRead() 781 payload[0] = 100 782 binary.BigEndian.PutUint16(payload, wsCloseStatusNormalClosure) 783 closeMsg = testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload[:1]) 784 785 // We will pass to wsRead a buffer of small capacity that contains 786 // only 1 byte. 787 closeFirtByte = []byte{closeMsg[0]} 788 // Make the io reader return the rest of the frame 789 tr.buf = closeMsg[1:] 790 bufs, err = c.wsRead(ri, tr, closeFirtByte[:]) 791 // It is expected that wsRead returns io.EOF on processing a close. 792 if err != io.EOF { 793 t.Fatalf("Unexpected error: %v", err) 794 } 795 if n := len(bufs); n != 0 { 796 t.Fatalf("Unexpected buffer returned: %v", n) 797 } 798 // A CLOSE should have been queued with the payload of the original close message. 799 c.mu.Lock() 800 nb, _ = c.collapsePtoNB() 801 c.mu.Unlock() 802 if n := len(nb); n == 0 { 803 t.Fatalf("Expected buffers, got %v", n) 804 } 805 if expected := 2 + 2; expected != len(nb[0]) { 806 t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) 807 } 808 b = nb[0][0] 809 if b&wsFinalBit == 0 { 810 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 811 } 812 if b&byte(wsCloseMessage) == 0 { 813 t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) 814 } 815 // Since satus was not valid, we should get wsCloseStatusNoStatusReceived 816 if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNoStatusReceived { 817 t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNoStatusReceived, status) 818 } 819 if len(nb[0][:]) != 4 { 820 t.Fatalf("Unexpected content: %s", nb[0][2:]) 821 } 822 } 823 824 func TestWSReadGetErrors(t *testing.T) { 825 tr := &testReader{err: fmt.Errorf("on purpose")} 826 for _, test := range []struct { 827 lenPayload int 828 rbextra int 829 }{ 830 {10, 1}, 831 {10, 3}, 832 {200, 1}, 833 {200, 2}, 834 {200, 5}, 835 {70000, 1}, 836 {70000, 5}, 837 {70000, 13}, 838 } { 839 t.Run("", func(t *testing.T) { 840 c, ri, _ := testWSSetupForRead() 841 msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg")) 842 frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, make([]byte, test.lenPayload)) 843 rb := append([]byte(nil), msg...) 844 rb = append(rb, frame...) 845 bufs, err := c.wsRead(ri, tr, rb[:len(msg)+test.rbextra]) 846 if err == nil || err.Error() != "on purpose" { 847 t.Fatalf("Expected 'on purpose' error, got %v", err) 848 } 849 if n := len(bufs); n != 1 { 850 t.Fatalf("Unexpected buffer returned: %v", n) 851 } 852 if string(bufs[0]) != "msg" { 853 t.Fatalf("Unexpected content: %s", bufs[0]) 854 } 855 }) 856 } 857 } 858 859 func TestWSHandleControlFrameErrors(t *testing.T) { 860 c, ri, tr := testWSSetupForRead() 861 tr.err = fmt.Errorf("on purpose") 862 863 // a close message has a status in 2 bytes + optional payload 864 text := []byte("this is a close message") 865 payload := make([]byte, 2+len(text)) 866 binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure) 867 copy(payload[2:], text) 868 ctrl := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload) 869 870 bufs, err := c.wsRead(ri, tr, ctrl[:len(ctrl)-4]) 871 if err == nil || err.Error() != "on purpose" { 872 t.Fatalf("Expected 'on purpose' error, got %v", err) 873 } 874 if n := len(bufs); n != 0 { 875 t.Fatalf("Unexpected buffer returned: %v", n) 876 } 877 878 // Alter the content of close message. It is supposed to be valid utf-8. 879 c, ri, tr = testWSSetupForRead() 880 cp := append([]byte(nil), payload...) 881 cp[10] = 0xF1 882 ctrl = testWSCreateClientMsg(wsCloseMessage, 1, true, false, cp) 883 bufs, err = c.wsRead(ri, tr, ctrl) 884 // We should still receive an EOF but the message enqueued to the client 885 // should contain wsCloseStatusInvalidPayloadData and the error about invalid utf8 886 if err != io.EOF { 887 t.Fatalf("Unexpected error: %v", err) 888 } 889 if n := len(bufs); n != 0 { 890 t.Fatalf("Unexpected buffer returned: %v", n) 891 } 892 c.mu.Lock() 893 nb, _ := c.collapsePtoNB() 894 c.mu.Unlock() 895 if n := len(nb); n == 0 { 896 t.Fatalf("Expected buffers, got %v", n) 897 } 898 b := nb[0][0] 899 if b&wsFinalBit == 0 { 900 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 901 } 902 if b&byte(wsCloseMessage) == 0 { 903 t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) 904 } 905 if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusInvalidPayloadData { 906 t.Fatalf("Expected status to be %v, got %v", wsCloseStatusInvalidPayloadData, status) 907 } 908 if !bytes.Contains(nb[0][4:], []byte("utf8")) { 909 t.Fatalf("Unexpected content: %s", nb[0][4:]) 910 } 911 } 912 913 func TestWSReadErrors(t *testing.T) { 914 for _, test := range []struct { 915 cframe func() []byte 916 err string 917 nbufs int 918 }{ 919 { 920 func() []byte { 921 msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello")) 922 msg[1] &= ^byte(wsMaskBit) 923 return msg 924 }, 925 "mask bit missing", 1, 926 }, 927 { 928 func() []byte { 929 return testWSCreateClientMsg(wsPingMessage, 1, true, false, make([]byte, 200)) 930 }, 931 "control frame length bigger than maximum allowed", 1, 932 }, 933 { 934 func() []byte { 935 return testWSCreateClientMsg(wsPingMessage, 1, false, false, []byte("hello")) 936 }, 937 "control frame does not have final bit set", 1, 938 }, 939 { 940 func() []byte { 941 frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("frag1")) 942 newMsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("new message")) 943 all := append([]byte(nil), frag1...) 944 all = append(all, newMsg...) 945 return all 946 }, 947 "new message started before final frame for previous message was received", 2, 948 }, 949 { 950 func() []byte { 951 frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("frame")) 952 frag := testWSCreateClientMsg(wsBinaryMessage, 2, false, false, []byte("continuation")) 953 all := append([]byte(nil), frame...) 954 all = append(all, frag...) 955 return all 956 }, 957 "invalid continuation frame", 2, 958 }, 959 { 960 func() []byte { 961 return testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frame")) 962 }, 963 "invalid continuation frame", 1, 964 }, 965 { 966 func() []byte { 967 return testWSCreateClientMsg(99, 1, false, false, []byte("hello")) 968 }, 969 "unknown opcode", 1, 970 }, 971 } { 972 t.Run(test.err, func(t *testing.T) { 973 c, ri, tr := testWSSetupForRead() 974 // Add a valid message first 975 msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello")) 976 // Then add the bad frame 977 bad := test.cframe() 978 // Add them both to a read buffer 979 rb := append([]byte(nil), msg...) 980 rb = append(rb, bad...) 981 bufs, err := c.wsRead(ri, tr, rb) 982 if err == nil || !strings.Contains(err.Error(), test.err) { 983 t.Fatalf("Expected error to contain %q, got %q", test.err, err.Error()) 984 } 985 if n := len(bufs); n != test.nbufs { 986 t.Fatalf("Unexpected number of buffers: %v", n) 987 } 988 if string(bufs[0]) != "hello" { 989 t.Fatalf("Unexpected content: %s", bufs[0]) 990 } 991 }) 992 } 993 } 994 995 func TestWSEnqueueCloseMsg(t *testing.T) { 996 for _, test := range []struct { 997 reason ClosedState 998 status int 999 }{ 1000 {ClientClosed, wsCloseStatusNormalClosure}, 1001 {AuthenticationTimeout, wsCloseStatusPolicyViolation}, 1002 {AuthenticationViolation, wsCloseStatusPolicyViolation}, 1003 {SlowConsumerPendingBytes, wsCloseStatusPolicyViolation}, 1004 {SlowConsumerWriteDeadline, wsCloseStatusPolicyViolation}, 1005 {MaxAccountConnectionsExceeded, wsCloseStatusPolicyViolation}, 1006 {MaxConnectionsExceeded, wsCloseStatusPolicyViolation}, 1007 {MaxControlLineExceeded, wsCloseStatusPolicyViolation}, 1008 {MaxSubscriptionsExceeded, wsCloseStatusPolicyViolation}, 1009 {MissingAccount, wsCloseStatusPolicyViolation}, 1010 {AuthenticationExpired, wsCloseStatusPolicyViolation}, 1011 {Revocation, wsCloseStatusPolicyViolation}, 1012 {TLSHandshakeError, wsCloseStatusTLSHandshake}, 1013 {ParseError, wsCloseStatusProtocolError}, 1014 {ProtocolViolation, wsCloseStatusProtocolError}, 1015 {BadClientProtocolVersion, wsCloseStatusProtocolError}, 1016 {MaxPayloadExceeded, wsCloseStatusMessageTooBig}, 1017 {ServerShutdown, wsCloseStatusGoingAway}, 1018 {WriteError, wsCloseStatusAbnormalClosure}, 1019 {ReadError, wsCloseStatusAbnormalClosure}, 1020 {StaleConnection, wsCloseStatusAbnormalClosure}, 1021 {ClosedState(254), wsCloseStatusInternalSrvError}, 1022 } { 1023 t.Run(test.reason.String(), func(t *testing.T) { 1024 c, _, _ := testWSSetupForRead() 1025 c.wsEnqueueCloseMessage(test.reason) 1026 c.mu.Lock() 1027 nb, _ := c.collapsePtoNB() 1028 c.mu.Unlock() 1029 if n := len(nb); n != 1 { 1030 t.Fatalf("Expected 1 buffer, got %v", n) 1031 } 1032 b := nb[0][0] 1033 if b&wsFinalBit == 0 { 1034 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 1035 } 1036 if b&byte(wsCloseMessage) == 0 { 1037 t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) 1038 } 1039 if status := binary.BigEndian.Uint16(nb[0][2:4]); int(status) != test.status { 1040 t.Fatalf("Expected status to be %v, got %v", test.status, status) 1041 } 1042 if string(nb[0][4:]) != test.reason.String() { 1043 t.Fatalf("Unexpected content: %s", nb[0][4:]) 1044 } 1045 }) 1046 } 1047 } 1048 1049 type testResponseWriter struct { 1050 http.ResponseWriter 1051 buf bytes.Buffer 1052 headers http.Header 1053 err error 1054 brw *bufio.ReadWriter 1055 conn *testWSFakeNetConn 1056 } 1057 1058 func (trw *testResponseWriter) Write(p []byte) (int, error) { 1059 return trw.buf.Write(p) 1060 } 1061 1062 func (trw *testResponseWriter) WriteHeader(status int) { 1063 trw.buf.WriteString(fmt.Sprintf("%v", status)) 1064 } 1065 1066 func (trw *testResponseWriter) Header() http.Header { 1067 if trw.headers == nil { 1068 trw.headers = make(http.Header) 1069 } 1070 return trw.headers 1071 } 1072 1073 type testWSFakeNetConn struct { 1074 net.Conn 1075 wbuf bytes.Buffer 1076 err error 1077 wsOpened bool 1078 isClosed bool 1079 deadlineCleared bool 1080 } 1081 1082 func (c *testWSFakeNetConn) Write(p []byte) (int, error) { 1083 if c.err != nil { 1084 return 0, c.err 1085 } 1086 return c.wbuf.Write(p) 1087 } 1088 1089 func (c *testWSFakeNetConn) SetDeadline(t time.Time) error { 1090 if t.IsZero() { 1091 c.deadlineCleared = true 1092 } 1093 return nil 1094 } 1095 1096 func (c *testWSFakeNetConn) Close() error { 1097 c.isClosed = true 1098 return nil 1099 } 1100 1101 func (trw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 1102 if trw.conn == nil { 1103 trw.conn = &testWSFakeNetConn{} 1104 } 1105 trw.conn.wsOpened = true 1106 if trw.brw == nil { 1107 trw.brw = bufio.NewReadWriter(bufio.NewReader(trw.conn), bufio.NewWriter(trw.conn)) 1108 } 1109 return trw.conn, trw.brw, trw.err 1110 } 1111 1112 func testWSOptions() *Options { 1113 opts := DefaultOptions() 1114 opts.DisableShortFirstPing = true 1115 opts.Websocket.Host = "127.0.0.1" 1116 opts.Websocket.Port = -1 1117 opts.NoSystemAccount = true 1118 var err error 1119 tc := &TLSConfigOpts{ 1120 CertFile: "./configs/certs/server.pem", 1121 KeyFile: "./configs/certs/key.pem", 1122 } 1123 opts.Websocket.TLSConfig, err = GenTLSConfig(tc) 1124 if err != nil { 1125 panic(err) 1126 } 1127 return opts 1128 } 1129 1130 func testWSCreateValidReq() *http.Request { 1131 req := &http.Request{ 1132 Method: "GET", 1133 Host: "localhost", 1134 Proto: "HTTP/1.1", 1135 } 1136 req.Header = make(http.Header) 1137 req.Header.Set("Upgrade", "websocket") 1138 req.Header.Set("Connection", "Upgrade") 1139 req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") 1140 req.Header.Set("Sec-Websocket-Version", "13") 1141 return req 1142 } 1143 1144 func TestWSCheckOrigin(t *testing.T) { 1145 notSameOrigin := false 1146 sameOrigin := true 1147 allowedListEmpty := []string{} 1148 someList := []string{"http://host1.com", "http://host2.com:1234"} 1149 1150 for _, test := range []struct { 1151 name string 1152 sameOrigin bool 1153 origins []string 1154 reqHost string 1155 reqTLS bool 1156 origin string 1157 err string 1158 }{ 1159 {"any", notSameOrigin, allowedListEmpty, "", false, "http://any.host.com", ""}, 1160 {"same origin ok", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:80", ""}, 1161 {"same origin bad host", sameOrigin, allowedListEmpty, "host.com", false, "http://other.host.com", "not same origin"}, 1162 {"same origin bad port", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:81", "not same origin"}, 1163 {"same origin bad scheme", sameOrigin, allowedListEmpty, "host.com", true, "http://host.com", "not same origin"}, 1164 {"same origin bad uri", sameOrigin, allowedListEmpty, "host.com", false, "@@@://invalid:url:1234", "invalid URI"}, 1165 {"same origin bad url", sameOrigin, allowedListEmpty, "host.com", false, "http://invalid:url:1234", "too many colons"}, 1166 {"same origin bad req host", sameOrigin, allowedListEmpty, "invalid:url:1234", false, "http://host.com", "too many colons"}, 1167 {"no origin same origin ignored", sameOrigin, allowedListEmpty, "", false, "", ""}, 1168 {"no origin list ignored", sameOrigin, someList, "", false, "", ""}, 1169 {"no origin same origin and list ignored", sameOrigin, someList, "", false, "", ""}, 1170 {"allowed from list", notSameOrigin, someList, "", false, "http://host2.com:1234", ""}, 1171 {"allowed with different path", notSameOrigin, someList, "", false, "http://host1.com/some/path", ""}, 1172 {"list bad port", notSameOrigin, someList, "", false, "http://host1.com:1234", "not in the allowed list"}, 1173 {"list bad scheme", notSameOrigin, someList, "", false, "https://host2.com:1234", "not in the allowed list"}, 1174 } { 1175 t.Run(test.name, func(t *testing.T) { 1176 opts := DefaultOptions() 1177 opts.Websocket.SameOrigin = test.sameOrigin 1178 opts.Websocket.AllowedOrigins = test.origins 1179 s := &Server{opts: opts} 1180 s.wsSetOriginOptions(&opts.Websocket) 1181 1182 req := testWSCreateValidReq() 1183 req.Host = test.reqHost 1184 if test.reqTLS { 1185 req.TLS = &tls.ConnectionState{} 1186 } 1187 if test.origin != "" { 1188 req.Header.Set("Origin", test.origin) 1189 } 1190 err := s.websocket.checkOrigin(req) 1191 if test.err == "" && err != nil { 1192 t.Fatalf("Unexpected error: %v", err) 1193 } else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) { 1194 t.Fatalf("Expected error %q, got %v", test.err, err) 1195 } 1196 }) 1197 } 1198 } 1199 1200 func TestWSUpgradeValidationErrors(t *testing.T) { 1201 for _, test := range []struct { 1202 name string 1203 setup func() (*Options, *testResponseWriter, *http.Request) 1204 err string 1205 status int 1206 }{ 1207 { 1208 "bad method", 1209 func() (*Options, *testResponseWriter, *http.Request) { 1210 opts := testWSOptions() 1211 req := testWSCreateValidReq() 1212 req.Method = "POST" 1213 return opts, nil, req 1214 }, 1215 "must be GET", 1216 http.StatusMethodNotAllowed, 1217 }, 1218 { 1219 "no host", 1220 func() (*Options, *testResponseWriter, *http.Request) { 1221 opts := testWSOptions() 1222 req := testWSCreateValidReq() 1223 req.Host = "" 1224 return opts, nil, req 1225 }, 1226 "'Host' missing in request", 1227 http.StatusBadRequest, 1228 }, 1229 { 1230 "invalid upgrade header", 1231 func() (*Options, *testResponseWriter, *http.Request) { 1232 opts := testWSOptions() 1233 req := testWSCreateValidReq() 1234 req.Header.Del("Upgrade") 1235 return opts, nil, req 1236 }, 1237 "invalid value for header 'Upgrade'", 1238 http.StatusBadRequest, 1239 }, 1240 { 1241 "invalid connection header", 1242 func() (*Options, *testResponseWriter, *http.Request) { 1243 opts := testWSOptions() 1244 req := testWSCreateValidReq() 1245 req.Header.Del("Connection") 1246 return opts, nil, req 1247 }, 1248 "invalid value for header 'Connection'", 1249 http.StatusBadRequest, 1250 }, 1251 { 1252 "no key", 1253 func() (*Options, *testResponseWriter, *http.Request) { 1254 opts := testWSOptions() 1255 req := testWSCreateValidReq() 1256 req.Header.Del("Sec-Websocket-Key") 1257 return opts, nil, req 1258 }, 1259 "key missing", 1260 http.StatusBadRequest, 1261 }, 1262 { 1263 "empty key", 1264 func() (*Options, *testResponseWriter, *http.Request) { 1265 opts := testWSOptions() 1266 req := testWSCreateValidReq() 1267 req.Header.Set("Sec-Websocket-Key", "") 1268 return opts, nil, req 1269 }, 1270 "key missing", 1271 http.StatusBadRequest, 1272 }, 1273 { 1274 "missing version", 1275 func() (*Options, *testResponseWriter, *http.Request) { 1276 opts := testWSOptions() 1277 req := testWSCreateValidReq() 1278 req.Header.Del("Sec-Websocket-Version") 1279 return opts, nil, req 1280 }, 1281 "invalid version", 1282 http.StatusBadRequest, 1283 }, 1284 { 1285 "wrong version", 1286 func() (*Options, *testResponseWriter, *http.Request) { 1287 opts := testWSOptions() 1288 req := testWSCreateValidReq() 1289 req.Header.Set("Sec-Websocket-Version", "99") 1290 return opts, nil, req 1291 }, 1292 "invalid version", 1293 http.StatusBadRequest, 1294 }, 1295 { 1296 "origin", 1297 func() (*Options, *testResponseWriter, *http.Request) { 1298 opts := testWSOptions() 1299 opts.Websocket.SameOrigin = true 1300 req := testWSCreateValidReq() 1301 req.Header.Set("Origin", "http://bad.host.com") 1302 return opts, nil, req 1303 }, 1304 "origin not allowed", 1305 http.StatusForbidden, 1306 }, 1307 { 1308 "hijack error", 1309 func() (*Options, *testResponseWriter, *http.Request) { 1310 opts := testWSOptions() 1311 rw := &testResponseWriter{err: fmt.Errorf("on purpose")} 1312 req := testWSCreateValidReq() 1313 return opts, rw, req 1314 }, 1315 "on purpose", 1316 http.StatusInternalServerError, 1317 }, 1318 { 1319 "hijack buffered data", 1320 func() (*Options, *testResponseWriter, *http.Request) { 1321 opts := testWSOptions() 1322 buf := &bytes.Buffer{} 1323 buf.WriteString("some data") 1324 rw := &testResponseWriter{ 1325 conn: &testWSFakeNetConn{}, 1326 brw: bufio.NewReadWriter(bufio.NewReader(buf), bufio.NewWriter(nil)), 1327 } 1328 tmp := [1]byte{} 1329 io.ReadAtLeast(rw.brw, tmp[:1], 1) 1330 req := testWSCreateValidReq() 1331 return opts, rw, req 1332 }, 1333 "client sent data before handshake is complete", 1334 http.StatusBadRequest, 1335 }, 1336 } { 1337 t.Run(test.name, func(t *testing.T) { 1338 opts, rw, req := test.setup() 1339 if rw == nil { 1340 rw = &testResponseWriter{} 1341 } 1342 s := &Server{opts: opts} 1343 s.wsSetOriginOptions(&opts.Websocket) 1344 res, err := s.wsUpgrade(rw, req) 1345 if err == nil || !strings.Contains(err.Error(), test.err) { 1346 t.Fatalf("Should get error %q, got %v", test.err, err) 1347 } 1348 if res != nil { 1349 t.Fatalf("Should not have returned a result, got %v", res) 1350 } 1351 expected := fmt.Sprintf("%v%s\n", test.status, http.StatusText(test.status)) 1352 if got := rw.buf.String(); got != expected { 1353 t.Fatalf("Expected %q got %q", expected, got) 1354 } 1355 // Check that if the connection was opened, it is now closed. 1356 if rw.conn != nil && rw.conn.wsOpened && !rw.conn.isClosed { 1357 t.Fatal("Connection was opened, but has not been closed") 1358 } 1359 }) 1360 } 1361 } 1362 1363 func TestWSUpgradeResponseWriteError(t *testing.T) { 1364 opts := testWSOptions() 1365 s := &Server{opts: opts} 1366 expectedErr := errors.New("on purpose") 1367 rw := &testResponseWriter{ 1368 conn: &testWSFakeNetConn{err: expectedErr}, 1369 } 1370 req := testWSCreateValidReq() 1371 res, err := s.wsUpgrade(rw, req) 1372 if err != expectedErr { 1373 t.Fatalf("Should get error %q, got %v", expectedErr.Error(), err) 1374 } 1375 if res != nil { 1376 t.Fatalf("Should not have returned a result, got %v", res) 1377 } 1378 if !rw.conn.isClosed { 1379 t.Fatal("Connection should have been closed") 1380 } 1381 } 1382 1383 func TestWSUpgradeConnDeadline(t *testing.T) { 1384 opts := testWSOptions() 1385 opts.Websocket.HandshakeTimeout = time.Second 1386 s := &Server{opts: opts} 1387 rw := &testResponseWriter{} 1388 req := testWSCreateValidReq() 1389 res, err := s.wsUpgrade(rw, req) 1390 if res == nil || err != nil { 1391 t.Fatalf("Unexpected error: %v", err) 1392 } 1393 if rw.conn.isClosed { 1394 t.Fatal("Connection should NOT have been closed") 1395 } 1396 if !rw.conn.deadlineCleared { 1397 t.Fatal("Connection deadline should have been cleared after handshake") 1398 } 1399 } 1400 1401 func TestWSCompressNegotiation(t *testing.T) { 1402 // No compression on the server, but client asks 1403 opts := testWSOptions() 1404 s := &Server{opts: opts} 1405 rw := &testResponseWriter{} 1406 req := testWSCreateValidReq() 1407 req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate") 1408 res, err := s.wsUpgrade(rw, req) 1409 if res == nil || err != nil { 1410 t.Fatalf("Unexpected error: %v", err) 1411 } 1412 // The http response should not contain "permessage-deflate" 1413 output := rw.conn.wbuf.String() 1414 if strings.Contains(output, "permessage-deflate") { 1415 t.Fatalf("Compression disabled in server so response to client should not contain extension, got %s", output) 1416 } 1417 1418 // Option in the server and client, so compression should be negotiated. 1419 s.opts.Websocket.Compression = true 1420 rw = &testResponseWriter{} 1421 res, err = s.wsUpgrade(rw, req) 1422 if res == nil || err != nil { 1423 t.Fatalf("Unexpected error: %v", err) 1424 } 1425 // The http response should not contain "permessage-deflate" 1426 output = rw.conn.wbuf.String() 1427 if !strings.Contains(output, "permessage-deflate") { 1428 t.Fatalf("Compression in server and client request, so response should contain extension, got %s", output) 1429 } 1430 1431 // Option in server but not asked by the client, so response should not contain "permessage-deflate" 1432 rw = &testResponseWriter{} 1433 req.Header.Del("Sec-Websocket-Extensions") 1434 res, err = s.wsUpgrade(rw, req) 1435 if res == nil || err != nil { 1436 t.Fatalf("Unexpected error: %v", err) 1437 } 1438 // The http response should not contain "permessage-deflate" 1439 output = rw.conn.wbuf.String() 1440 if strings.Contains(output, "permessage-deflate") { 1441 t.Fatalf("Compression in server but not in client, so response to client should not contain extension, got %s", output) 1442 } 1443 } 1444 1445 func TestWSParseOptions(t *testing.T) { 1446 for _, test := range []struct { 1447 name string 1448 content string 1449 checkOpt func(*WebsocketOpts) error 1450 err string 1451 }{ 1452 // Negative tests 1453 {"bad type", "websocket: []", nil, "to be a map"}, 1454 {"bad listen", "websocket: { listen: [] }", nil, "port or host:port"}, 1455 {"bad port", `websocket: { port: "abc" }`, nil, "not int64"}, 1456 {"bad host", `websocket: { host: 123 }`, nil, "not string"}, 1457 {"bad advertise type", `websocket: { advertise: 123 }`, nil, "not string"}, 1458 {"bad tls", `websocket: { tls: 123 }`, nil, "not map[string]interface {}"}, 1459 {"bad same origin", `websocket: { same_origin: "abc" }`, nil, "not bool"}, 1460 {"bad allowed origins type", `websocket: { allowed_origins: {} }`, nil, "unsupported type"}, 1461 {"bad allowed origins values", `websocket: { allowed_origins: [ {} ] }`, nil, "unsupported type in array"}, 1462 {"bad handshake timeout type", `websocket: { handshake_timeout: [] }`, nil, "unsupported type"}, 1463 {"bad handshake timeout duration", `websocket: { handshake_timeout: "abc" }`, nil, "invalid duration"}, 1464 {"unknown field", `websocket: { this_does_not_exist: 123 }`, nil, "unknown"}, 1465 // Positive tests 1466 {"listen port only", `websocket { listen: 1234 }`, func(wo *WebsocketOpts) error { 1467 if wo.Port != 1234 { 1468 return fmt.Errorf("expected 1234, got %v", wo.Port) 1469 } 1470 return nil 1471 }, ""}, 1472 {"listen host and port", `websocket { listen: "localhost:1234" }`, func(wo *WebsocketOpts) error { 1473 if wo.Host != "localhost" || wo.Port != 1234 { 1474 return fmt.Errorf("expected localhost:1234, got %v:%v", wo.Host, wo.Port) 1475 } 1476 return nil 1477 }, ""}, 1478 {"host", `websocket { host: "localhost" }`, func(wo *WebsocketOpts) error { 1479 if wo.Host != "localhost" { 1480 return fmt.Errorf("expected localhost, got %v", wo.Host) 1481 } 1482 return nil 1483 }, ""}, 1484 {"port", `websocket { port: 1234 }`, func(wo *WebsocketOpts) error { 1485 if wo.Port != 1234 { 1486 return fmt.Errorf("expected 1234, got %v", wo.Port) 1487 } 1488 return nil 1489 }, ""}, 1490 {"advertise", `websocket { advertise: "host:1234" }`, func(wo *WebsocketOpts) error { 1491 if wo.Advertise != "host:1234" { 1492 return fmt.Errorf("expected %q, got %q", "host:1234", wo.Advertise) 1493 } 1494 return nil 1495 }, ""}, 1496 {"same origin", `websocket { same_origin: true }`, func(wo *WebsocketOpts) error { 1497 if !wo.SameOrigin { 1498 return fmt.Errorf("expected same_origin==true, got %v", wo.SameOrigin) 1499 } 1500 return nil 1501 }, ""}, 1502 {"allowed origins one only", `websocket { allowed_origins: "https://host.com/" }`, func(wo *WebsocketOpts) error { 1503 expected := []string{"https://host.com/"} 1504 if !reflect.DeepEqual(wo.AllowedOrigins, expected) { 1505 return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins) 1506 } 1507 return nil 1508 }, ""}, 1509 {"allowed origins array", 1510 ` 1511 websocket { 1512 allowed_origins: [ 1513 "https://host1.com/" 1514 "https://host2.com/" 1515 ] 1516 } 1517 `, func(wo *WebsocketOpts) error { 1518 expected := []string{"https://host1.com/", "https://host2.com/"} 1519 if !reflect.DeepEqual(wo.AllowedOrigins, expected) { 1520 return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins) 1521 } 1522 return nil 1523 }, ""}, 1524 {"handshake timeout in whole seconds", `websocket { handshake_timeout: 3 }`, func(wo *WebsocketOpts) error { 1525 if wo.HandshakeTimeout != 3*time.Second { 1526 return fmt.Errorf("expected handshake to be 3s, got %v", wo.HandshakeTimeout) 1527 } 1528 return nil 1529 }, ""}, 1530 {"handshake timeout n duration", `websocket { handshake_timeout: "4s" }`, func(wo *WebsocketOpts) error { 1531 if wo.HandshakeTimeout != 4*time.Second { 1532 return fmt.Errorf("expected handshake to be 4s, got %v", wo.HandshakeTimeout) 1533 } 1534 return nil 1535 }, ""}, 1536 {"tls config", 1537 ` 1538 websocket { 1539 tls { 1540 cert_file: "./configs/certs/server.pem" 1541 key_file: "./configs/certs/key.pem" 1542 } 1543 } 1544 `, func(wo *WebsocketOpts) error { 1545 if wo.TLSConfig == nil { 1546 return fmt.Errorf("TLSConfig should have been set") 1547 } 1548 return nil 1549 }, ""}, 1550 {"compression", 1551 ` 1552 websocket { 1553 compression: true 1554 } 1555 `, func(wo *WebsocketOpts) error { 1556 if !wo.Compression { 1557 return fmt.Errorf("Compression should have been set") 1558 } 1559 return nil 1560 }, ""}, 1561 {"jwt cookie", 1562 ` 1563 websocket { 1564 jwt_cookie: "jwtcookie" 1565 } 1566 `, func(wo *WebsocketOpts) error { 1567 if wo.JWTCookie != "jwtcookie" { 1568 return fmt.Errorf("Invalid JWTCookie value: %q", wo.JWTCookie) 1569 } 1570 return nil 1571 }, ""}, 1572 {"no auth user", 1573 ` 1574 websocket { 1575 no_auth_user: "noauthuser" 1576 } 1577 `, func(wo *WebsocketOpts) error { 1578 if wo.NoAuthUser != "noauthuser" { 1579 return fmt.Errorf("Invalid NoAuthUser value: %q", wo.NoAuthUser) 1580 } 1581 return nil 1582 }, ""}, 1583 {"auth block", 1584 ` 1585 websocket { 1586 authorization { 1587 user: "webuser" 1588 password: "pwd" 1589 token: "token" 1590 timeout: 2.0 1591 } 1592 } 1593 `, func(wo *WebsocketOpts) error { 1594 if wo.Username != "webuser" || wo.Password != "pwd" || wo.Token != "token" || wo.AuthTimeout != 2.0 { 1595 return fmt.Errorf("Invalid auth block: %+v", wo) 1596 } 1597 return nil 1598 }, ""}, 1599 {"auth timeout as int", 1600 ` 1601 websocket { 1602 authorization { 1603 timeout: 2 1604 } 1605 } 1606 `, func(wo *WebsocketOpts) error { 1607 if wo.AuthTimeout != 2.0 { 1608 return fmt.Errorf("Invalid auth timeout: %v", wo.AuthTimeout) 1609 } 1610 return nil 1611 }, ""}, 1612 } { 1613 t.Run(test.name, func(t *testing.T) { 1614 conf := createConfFile(t, []byte(test.content)) 1615 o, err := ProcessConfigFile(conf) 1616 if test.err != _EMPTY_ { 1617 if err == nil || !strings.Contains(err.Error(), test.err) { 1618 t.Fatalf("For content: %q, expected error about %q, got %v", test.content, test.err, err) 1619 } 1620 return 1621 } else if err != nil { 1622 t.Fatalf("Unexpected error for content %q: %v", test.content, err) 1623 } 1624 if err := test.checkOpt(&o.Websocket); err != nil { 1625 t.Fatalf("Incorrect option for content %q: %v", test.content, err.Error()) 1626 } 1627 }) 1628 } 1629 } 1630 1631 func TestWSValidateOptions(t *testing.T) { 1632 nwso := DefaultOptions() 1633 wso := testWSOptions() 1634 for _, test := range []struct { 1635 name string 1636 getOpts func() *Options 1637 err string 1638 }{ 1639 {"websocket disabled", func() *Options { return nwso.Clone() }, ""}, 1640 {"no tls", func() *Options { o := wso.Clone(); o.Websocket.TLSConfig = nil; return o }, "requires TLS configuration"}, 1641 {"bad url in allowed list", func() *Options { 1642 o := wso.Clone() 1643 o.Websocket.AllowedOrigins = []string{"http://this:is:bad:url"} 1644 return o 1645 }, "unable to parse"}, 1646 {"missing trusted configuration", func() *Options { 1647 o := wso.Clone() 1648 o.Websocket.JWTCookie = "jwt" 1649 return o 1650 }, "keys configuration is required"}, 1651 {"websocket username not allowed if users specified", func() *Options { 1652 o := wso.Clone() 1653 o.Nkeys = []*NkeyUser{{Nkey: "abc"}} 1654 o.Websocket.Username = "b" 1655 o.Websocket.Password = "pwd" 1656 return o 1657 }, "websocket authentication username not compatible with presence of users/nkeys"}, 1658 {"websocket token not allowed if users specified", func() *Options { 1659 o := wso.Clone() 1660 o.Nkeys = []*NkeyUser{{Nkey: "abc"}} 1661 o.Websocket.Token = "mytoken" 1662 return o 1663 }, "websocket authentication token not compatible with presence of users/nkeys"}, 1664 } { 1665 t.Run(test.name, func(t *testing.T) { 1666 err := validateWebsocketOptions(test.getOpts()) 1667 if test.err == "" && err != nil { 1668 t.Fatalf("Unexpected error: %v", err) 1669 } else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) { 1670 t.Fatalf("Expected error to contain %q, got %v", test.err, err) 1671 } 1672 }) 1673 } 1674 } 1675 1676 func TestWSSetOriginOptions(t *testing.T) { 1677 o := testWSOptions() 1678 for _, test := range []struct { 1679 content string 1680 err string 1681 }{ 1682 {"@@@://host.com/", "invalid URI"}, 1683 {"http://this:is:bad:url/", "invalid port"}, 1684 } { 1685 t.Run(test.err, func(t *testing.T) { 1686 o.Websocket.AllowedOrigins = []string{test.content} 1687 s := &Server{} 1688 l := &captureErrorLogger{errCh: make(chan string, 1)} 1689 s.SetLogger(l, false, false) 1690 s.wsSetOriginOptions(&o.Websocket) 1691 select { 1692 case e := <-l.errCh: 1693 if !strings.Contains(e, test.err) { 1694 t.Fatalf("Unexpected error: %v", e) 1695 } 1696 case <-time.After(50 * time.Millisecond): 1697 t.Fatalf("Did not get the error") 1698 } 1699 1700 }) 1701 } 1702 } 1703 1704 type captureFatalLogger struct { 1705 DummyLogger 1706 fatalCh chan string 1707 } 1708 1709 func (l *captureFatalLogger) Fatalf(format string, v ...interface{}) { 1710 select { 1711 case l.fatalCh <- fmt.Sprintf(format, v...): 1712 default: 1713 } 1714 } 1715 1716 func TestWSFailureToStartServer(t *testing.T) { 1717 // Create a listener to use a port 1718 l, err := net.Listen("tcp", "127.0.0.1:0") 1719 if err != nil { 1720 t.Fatalf("Error listening: %v", err) 1721 } 1722 defer l.Close() 1723 1724 o := testWSOptions() 1725 // Make sure we don't have unnecessary listen ports opened. 1726 o.HTTPPort = 0 1727 o.Cluster.Port = 0 1728 o.Gateway.Name = "" 1729 o.Gateway.Port = 0 1730 o.LeafNode.Port = 0 1731 o.Websocket.Port = l.Addr().(*net.TCPAddr).Port 1732 s, err := NewServer(o) 1733 if err != nil { 1734 t.Fatalf("Error creating server: %v", err) 1735 } 1736 defer s.Shutdown() 1737 logger := &captureFatalLogger{fatalCh: make(chan string, 1)} 1738 s.SetLogger(logger, false, false) 1739 1740 wg := sync.WaitGroup{} 1741 wg.Add(1) 1742 go func() { 1743 s.Start() 1744 wg.Done() 1745 }() 1746 1747 select { 1748 case e := <-logger.fatalCh: 1749 if !strings.Contains(e, "Unable to listen") { 1750 t.Fatalf("Unexpected error: %v", e) 1751 } 1752 case <-time.After(2 * time.Second): 1753 t.Fatalf("Should have reported a fatal error") 1754 } 1755 // Since this is a test and the process does not actually 1756 // exit on Fatal error, wait for the client port to be 1757 // ready so when we shutdown we don't leave the accept 1758 // loop hanging. 1759 checkFor(t, time.Second, 15*time.Millisecond, func() error { 1760 s.mu.Lock() 1761 ready := s.listener != nil 1762 s.mu.Unlock() 1763 if !ready { 1764 return fmt.Errorf("client accept loop not started yet") 1765 } 1766 return nil 1767 }) 1768 s.Shutdown() 1769 wg.Wait() 1770 } 1771 1772 func TestWSAbnormalFailureOfWebServer(t *testing.T) { 1773 o := testWSOptions() 1774 s := RunServer(o) 1775 defer s.Shutdown() 1776 logger := &captureFatalLogger{fatalCh: make(chan string, 1)} 1777 s.SetLogger(logger, false, false) 1778 1779 // Now close the WS listener to cause a WebServer error 1780 s.mu.Lock() 1781 s.websocket.listener.Close() 1782 s.mu.Unlock() 1783 1784 select { 1785 case e := <-logger.fatalCh: 1786 if !strings.Contains(e, "websocket listener error") { 1787 t.Fatalf("Unexpected error: %v", e) 1788 } 1789 case <-time.After(2 * time.Second): 1790 t.Fatalf("Should have reported a fatal error") 1791 } 1792 } 1793 1794 type testWSClientOptions struct { 1795 compress, web bool 1796 host string 1797 port int 1798 extraHeaders map[string][]string 1799 noTLS bool 1800 path string 1801 } 1802 1803 func testNewWSClient(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte) { 1804 t.Helper() 1805 c, br, info, err := testNewWSClientWithError(t, o) 1806 if err != nil { 1807 t.Fatal(err) 1808 } 1809 return c, br, info 1810 } 1811 1812 func testNewWSClientWithError(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte, error) { 1813 addr := fmt.Sprintf("%s:%d", o.host, o.port) 1814 wsc, err := net.Dial("tcp", addr) 1815 if err != nil { 1816 return nil, nil, nil, fmt.Errorf("Error creating ws connection: %v", err) 1817 } 1818 if !o.noTLS { 1819 wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true}) 1820 wsc.SetDeadline(time.Now().Add(time.Second)) 1821 if err := wsc.(*tls.Conn).Handshake(); err != nil { 1822 return nil, nil, nil, fmt.Errorf("Error during handshake: %v", err) 1823 } 1824 wsc.SetDeadline(time.Time{}) 1825 } 1826 req := testWSCreateValidReq() 1827 if o.compress { 1828 req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate") 1829 } 1830 if o.web { 1831 req.Header.Set("User-Agent", "Mozilla/5.0") 1832 } 1833 if len(o.extraHeaders) > 0 { 1834 for hdr, values := range o.extraHeaders { 1835 if len(values) == 0 { 1836 req.Header.Set(hdr, _EMPTY_) 1837 continue 1838 } 1839 req.Header.Set(hdr, values[0]) 1840 for i := 1; i < len(values); i++ { 1841 req.Header.Add(hdr, values[i]) 1842 } 1843 } 1844 } 1845 req.URL, _ = url.Parse("wss://" + addr + o.path) 1846 if err := req.Write(wsc); err != nil { 1847 return nil, nil, nil, fmt.Errorf("Error sending request: %v", err) 1848 } 1849 br := bufio.NewReader(wsc) 1850 resp, err := http.ReadResponse(br, req) 1851 if err != nil { 1852 return nil, nil, nil, fmt.Errorf("Error reading response: %v", err) 1853 } 1854 defer resp.Body.Close() 1855 if resp.StatusCode != http.StatusSwitchingProtocols { 1856 return nil, nil, nil, fmt.Errorf("Expected response status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) 1857 } 1858 var info []byte 1859 if o.path == mqttWSPath { 1860 if v := resp.Header[wsSecProto]; len(v) != 1 || v[0] != wsMQTTSecProtoVal { 1861 return nil, nil, nil, fmt.Errorf("No mqtt protocol in header: %v", resp.Header) 1862 } 1863 } else { 1864 // Wait for the INFO 1865 info = testWSReadFrame(t, br) 1866 if !bytes.HasPrefix(info, []byte("INFO {")) { 1867 return nil, nil, nil, fmt.Errorf("Expected INFO, got %s", info) 1868 } 1869 } 1870 return wsc, br, info, nil 1871 } 1872 1873 type testClaimsOptions struct { 1874 nac *jwt.AccountClaims 1875 nuc *jwt.UserClaims 1876 connectRequest interface{} 1877 dontSign bool 1878 expectAnswer string 1879 } 1880 1881 func testWSWithClaims(t *testing.T, s *Server, o testWSClientOptions, tclm testClaimsOptions) (kp nkeys.KeyPair, conn net.Conn, rdr *bufio.Reader, auth_was_required bool) { 1882 t.Helper() 1883 1884 okp, _ := nkeys.FromSeed(oSeed) 1885 1886 akp, _ := nkeys.CreateAccount() 1887 apub, _ := akp.PublicKey() 1888 if tclm.nac == nil { 1889 tclm.nac = jwt.NewAccountClaims(apub) 1890 } else { 1891 tclm.nac.Subject = apub 1892 } 1893 ajwt, err := tclm.nac.Encode(okp) 1894 if err != nil { 1895 t.Fatalf("Error generating account JWT: %v", err) 1896 } 1897 1898 nkp, _ := nkeys.CreateUser() 1899 pub, _ := nkp.PublicKey() 1900 if tclm.nuc == nil { 1901 tclm.nuc = jwt.NewUserClaims(pub) 1902 } else { 1903 tclm.nuc.Subject = pub 1904 } 1905 jwt, err := tclm.nuc.Encode(akp) 1906 if err != nil { 1907 t.Fatalf("Error generating user JWT: %v", err) 1908 } 1909 1910 addAccountToMemResolver(s, apub, ajwt) 1911 1912 c, cr, l := testNewWSClient(t, o) 1913 1914 var info struct { 1915 Nonce string `json:"nonce,omitempty"` 1916 AuthRequired bool `json:"auth_required,omitempty"` 1917 } 1918 1919 if err := json.Unmarshal([]byte(l[5:]), &info); err != nil { 1920 t.Fatal(err) 1921 } 1922 if info.AuthRequired { 1923 cs := "" 1924 if tclm.connectRequest != nil { 1925 customReq, err := json.Marshal(tclm.connectRequest) 1926 if err != nil { 1927 t.Fatal(err) 1928 } 1929 // PING needed to flush the +OK/-ERR to us. 1930 cs = fmt.Sprintf("CONNECT %v\r\nPING\r\n", string(customReq)) 1931 } else if !tclm.dontSign { 1932 // Sign Nonce 1933 sigraw, _ := nkp.Sign([]byte(info.Nonce)) 1934 sig := base64.RawURLEncoding.EncodeToString(sigraw) 1935 cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"sig\":\"%s\",\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt, sig) 1936 } else { 1937 cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt) 1938 } 1939 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(cs)) 1940 c.Write(wsmsg) 1941 l = testWSReadFrame(t, cr) 1942 if !strings.HasPrefix(string(l), tclm.expectAnswer) { 1943 t.Fatalf("Expected %q, got %q", tclm.expectAnswer, l) 1944 } 1945 } 1946 return akp, c, cr, info.AuthRequired 1947 } 1948 1949 func setupAddTrusted(o *Options) { 1950 kp, _ := nkeys.FromSeed(oSeed) 1951 pub, _ := kp.PublicKey() 1952 o.TrustedKeys = []string{pub} 1953 } 1954 1955 func setupAddCookie(o *Options) { 1956 o.Websocket.JWTCookie = "jwt" 1957 } 1958 1959 func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int, cookies ...string) (net.Conn, *bufio.Reader, []byte) { 1960 t.Helper() 1961 opts := testWSClientOptions{ 1962 compress: compress, 1963 web: web, 1964 host: host, 1965 port: port, 1966 } 1967 1968 if len(cookies) > 0 { 1969 opts.extraHeaders = map[string][]string{} 1970 opts.extraHeaders["Cookie"] = cookies 1971 } 1972 return testNewWSClient(t, opts) 1973 } 1974 1975 func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) { 1976 wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port) 1977 // Send CONNECT and PING 1978 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n")) 1979 if _, err := wsc.Write(wsmsg); err != nil { 1980 t.Fatalf("Error sending message: %v", err) 1981 } 1982 // Wait for the PONG 1983 if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 1984 t.Fatalf("Expected PONG, got %s", msg) 1985 } 1986 return wsc, br 1987 } 1988 1989 func testWSReadFrame(t testing.TB, br *bufio.Reader) []byte { 1990 t.Helper() 1991 fh := [2]byte{} 1992 if _, err := io.ReadAtLeast(br, fh[:2], 2); err != nil { 1993 t.Fatalf("Error reading frame: %v", err) 1994 } 1995 fc := fh[0]&wsRsv1Bit != 0 1996 sb := fh[1] 1997 size := 0 1998 switch { 1999 case sb <= 125: 2000 size = int(sb) 2001 case sb == 126: 2002 tmp := [2]byte{} 2003 if _, err := io.ReadAtLeast(br, tmp[:2], 2); err != nil { 2004 t.Fatalf("Error reading frame: %v", err) 2005 } 2006 size = int(binary.BigEndian.Uint16(tmp[:2])) 2007 case sb == 127: 2008 tmp := [8]byte{} 2009 if _, err := io.ReadAtLeast(br, tmp[:8], 8); err != nil { 2010 t.Fatalf("Error reading frame: %v", err) 2011 } 2012 size = int(binary.BigEndian.Uint64(tmp[:8])) 2013 } 2014 buf := make([]byte, size) 2015 if _, err := io.ReadAtLeast(br, buf, size); err != nil { 2016 t.Fatalf("Error reading frame: %v", err) 2017 } 2018 if !fc { 2019 return buf 2020 } 2021 buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff) 2022 dbr := bytes.NewBuffer(buf) 2023 d := flate.NewReader(dbr) 2024 uncompressed, err := io.ReadAll(d) 2025 if err != nil { 2026 t.Fatalf("Error reading frame: %v", err) 2027 } 2028 return uncompressed 2029 } 2030 2031 func TestWSPubSub(t *testing.T) { 2032 for _, test := range []struct { 2033 name string 2034 compression bool 2035 }{ 2036 {"no compression", false}, 2037 {"compression", true}, 2038 } { 2039 t.Run(test.name, func(t *testing.T) { 2040 o := testWSOptions() 2041 if test.compression { 2042 o.Websocket.Compression = true 2043 } 2044 s := RunServer(o) 2045 defer s.Shutdown() 2046 2047 // Create a regular client to subscribe 2048 nc := natsConnect(t, s.ClientURL()) 2049 defer nc.Close() 2050 nsub := natsSubSync(t, nc, "foo") 2051 checkExpectedSubs(t, 1, s) 2052 2053 // Now create a WS client and send a message on "foo" 2054 wsc, br := testWSCreateClient(t, test.compression, false, o.Websocket.Host, o.Websocket.Port) 2055 defer wsc.Close() 2056 2057 // Send a WS message for "PUB foo 2\r\nok\r\n" 2058 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n")) 2059 if _, err := wsc.Write(wsmsg); err != nil { 2060 t.Fatalf("Error sending message: %v", err) 2061 } 2062 2063 // Now check that message is received 2064 msg := natsNexMsg(t, nsub, time.Second) 2065 if string(msg.Data) != "from ws" { 2066 t.Fatalf("Expected message to be %q, got %q", "ok", string(msg.Data)) 2067 } 2068 2069 // Now do reverse, create a subscription on WS client on bar 2070 wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB bar 1\r\n")) 2071 if _, err := wsc.Write(wsmsg); err != nil { 2072 t.Fatalf("Error sending subscription: %v", err) 2073 } 2074 // Wait for it to be registered on server 2075 checkExpectedSubs(t, 2, s) 2076 // Now publish from NATS connection and verify received on WS client 2077 natsPub(t, nc, "bar", []byte("from nats")) 2078 natsFlush(t, nc) 2079 2080 // Check for the "from nats" message... 2081 // Set some deadline so we are not stuck forever on failure 2082 wsc.SetReadDeadline(time.Now().Add(10 * time.Second)) 2083 ok := 0 2084 for { 2085 line, _, err := br.ReadLine() 2086 if err != nil { 2087 t.Fatalf("Error reading: %v", err) 2088 } 2089 // Note that this works even in compression test because those 2090 // texts are likely not to be compressed, but compression code is 2091 // still executed. 2092 if ok == 0 && bytes.Contains(line, []byte("MSG bar 1 9")) { 2093 ok = 1 2094 continue 2095 } else if ok == 1 && bytes.Contains(line, []byte("from nats")) { 2096 break 2097 } 2098 } 2099 }) 2100 } 2101 } 2102 2103 func TestWSTLSConnection(t *testing.T) { 2104 o := testWSOptions() 2105 s := RunServer(o) 2106 defer s.Shutdown() 2107 2108 addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) 2109 2110 for _, test := range []struct { 2111 name string 2112 useTLS bool 2113 status int 2114 }{ 2115 {"client uses TLS", true, http.StatusSwitchingProtocols}, 2116 {"client does not use TLS", false, http.StatusBadRequest}, 2117 } { 2118 t.Run(test.name, func(t *testing.T) { 2119 wsc, err := net.Dial("tcp", addr) 2120 if err != nil { 2121 t.Fatalf("Error creating ws connection: %v", err) 2122 } 2123 defer wsc.Close() 2124 if test.useTLS { 2125 wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true}) 2126 if err := wsc.(*tls.Conn).Handshake(); err != nil { 2127 t.Fatalf("Error during handshake: %v", err) 2128 } 2129 } 2130 req := testWSCreateValidReq() 2131 var scheme string 2132 if test.useTLS { 2133 scheme = "s" 2134 } 2135 req.URL, _ = url.Parse("ws" + scheme + "://" + addr) 2136 if err := req.Write(wsc); err != nil { 2137 t.Fatalf("Error sending request: %v", err) 2138 } 2139 br := bufio.NewReader(wsc) 2140 resp, err := http.ReadResponse(br, req) 2141 if err != nil { 2142 t.Fatalf("Error reading response: %v", err) 2143 } 2144 defer resp.Body.Close() 2145 if resp.StatusCode != test.status { 2146 t.Fatalf("Expected status %v, got %v", test.status, resp.StatusCode) 2147 } 2148 }) 2149 } 2150 } 2151 2152 func TestWSTLSVerifyClientCert(t *testing.T) { 2153 o := testWSOptions() 2154 tc := &TLSConfigOpts{ 2155 CertFile: "../test/configs/certs/server-cert.pem", 2156 KeyFile: "../test/configs/certs/server-key.pem", 2157 CaFile: "../test/configs/certs/ca.pem", 2158 Verify: true, 2159 } 2160 tlsc, err := GenTLSConfig(tc) 2161 if err != nil { 2162 t.Fatalf("Error creating tls config: %v", err) 2163 } 2164 o.Websocket.TLSConfig = tlsc 2165 s := RunServer(o) 2166 defer s.Shutdown() 2167 2168 addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) 2169 2170 for _, test := range []struct { 2171 name string 2172 provideCert bool 2173 }{ 2174 {"client provides cert", true}, 2175 {"client does not provide cert", false}, 2176 } { 2177 t.Run(test.name, func(t *testing.T) { 2178 wsc, err := net.Dial("tcp", addr) 2179 if err != nil { 2180 t.Fatalf("Error creating ws connection: %v", err) 2181 } 2182 defer wsc.Close() 2183 tlsc := &tls.Config{} 2184 if test.provideCert { 2185 tc := &TLSConfigOpts{ 2186 CertFile: "../test/configs/certs/client-cert.pem", 2187 KeyFile: "../test/configs/certs/client-key.pem", 2188 } 2189 var err error 2190 tlsc, err = GenTLSConfig(tc) 2191 if err != nil { 2192 t.Fatalf("Error generating tls config: %v", err) 2193 } 2194 } 2195 tlsc.InsecureSkipVerify = true 2196 wsc = tls.Client(wsc, tlsc) 2197 if err := wsc.(*tls.Conn).Handshake(); err != nil { 2198 t.Fatalf("Error during handshake: %v", err) 2199 } 2200 req := testWSCreateValidReq() 2201 req.URL, _ = url.Parse("wss://" + addr) 2202 if err := req.Write(wsc); err != nil { 2203 t.Fatalf("Error sending request: %v", err) 2204 } 2205 br := bufio.NewReader(wsc) 2206 resp, err := http.ReadResponse(br, req) 2207 if resp != nil { 2208 resp.Body.Close() 2209 } 2210 if !test.provideCert { 2211 if err == nil { 2212 t.Fatal("Expected error, did not get one") 2213 } else if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "certificate required") { 2214 t.Fatalf("Unexpected error: %v", err) 2215 } 2216 return 2217 } 2218 if err != nil { 2219 t.Fatalf("Unexpected error: %v", err) 2220 } 2221 if resp.StatusCode != http.StatusSwitchingProtocols { 2222 t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) 2223 } 2224 }) 2225 } 2226 } 2227 2228 func testCreateAllowedConnectionTypes(list []string) map[string]struct{} { 2229 if len(list) == 0 { 2230 return nil 2231 } 2232 m := make(map[string]struct{}, len(list)) 2233 for _, l := range list { 2234 m[l] = struct{}{} 2235 } 2236 return m 2237 } 2238 2239 func TestWSTLSVerifyAndMap(t *testing.T) { 2240 accName := "MyAccount" 2241 acc := NewAccount(accName) 2242 certUserName := "CN=example.com,OU=NATS.io" 2243 users := []*User{{Username: certUserName, Account: acc}} 2244 2245 for _, test := range []struct { 2246 name string 2247 filtering bool 2248 provideCert bool 2249 }{ 2250 {"no filtering, client provides cert", false, true}, 2251 {"no filtering, client does not provide cert", false, false}, 2252 {"filtering, client provides cert", true, true}, 2253 {"filtering, client does not provide cert", true, false}, 2254 {"no users override, client provides cert", false, true}, 2255 {"no users override, client does not provide cert", false, false}, 2256 {"users override, client provides cert", true, true}, 2257 {"users override, client does not provide cert", true, false}, 2258 } { 2259 t.Run(test.name, func(t *testing.T) { 2260 o := testWSOptions() 2261 o.Accounts = []*Account{acc} 2262 o.Users = users 2263 if test.filtering { 2264 o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}) 2265 } 2266 tc := &TLSConfigOpts{ 2267 CertFile: "../test/configs/certs/tlsauth/server.pem", 2268 KeyFile: "../test/configs/certs/tlsauth/server-key.pem", 2269 CaFile: "../test/configs/certs/tlsauth/ca.pem", 2270 Verify: true, 2271 } 2272 tlsc, err := GenTLSConfig(tc) 2273 if err != nil { 2274 t.Fatalf("Error creating tls config: %v", err) 2275 } 2276 o.Websocket.TLSConfig = tlsc 2277 o.Websocket.TLSMap = true 2278 s := RunServer(o) 2279 defer s.Shutdown() 2280 2281 addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) 2282 wsc, err := net.Dial("tcp", addr) 2283 if err != nil { 2284 t.Fatalf("Error creating ws connection: %v", err) 2285 } 2286 defer wsc.Close() 2287 tlscc := &tls.Config{} 2288 if test.provideCert { 2289 tc := &TLSConfigOpts{ 2290 CertFile: "../test/configs/certs/tlsauth/client.pem", 2291 KeyFile: "../test/configs/certs/tlsauth/client-key.pem", 2292 } 2293 var err error 2294 tlscc, err = GenTLSConfig(tc) 2295 if err != nil { 2296 t.Fatalf("Error generating tls config: %v", err) 2297 } 2298 } 2299 tlscc.InsecureSkipVerify = true 2300 wsc = tls.Client(wsc, tlscc) 2301 if err := wsc.(*tls.Conn).Handshake(); err != nil { 2302 t.Fatalf("Error during handshake: %v", err) 2303 } 2304 req := testWSCreateValidReq() 2305 req.URL, _ = url.Parse("wss://" + addr) 2306 if err := req.Write(wsc); err != nil { 2307 t.Fatalf("Error sending request: %v", err) 2308 } 2309 br := bufio.NewReader(wsc) 2310 resp, err := http.ReadResponse(br, req) 2311 if resp != nil { 2312 resp.Body.Close() 2313 } 2314 if !test.provideCert { 2315 if err == nil { 2316 t.Fatal("Expected error, did not get one") 2317 } else if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "certificate required") { 2318 t.Fatalf("Unexpected error: %v", err) 2319 } 2320 return 2321 } 2322 if err != nil { 2323 t.Fatalf("Unexpected error: %v", err) 2324 } 2325 if resp.StatusCode != http.StatusSwitchingProtocols { 2326 t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) 2327 } 2328 // Wait for the INFO 2329 l := testWSReadFrame(t, br) 2330 if !bytes.HasPrefix(l, []byte("INFO {")) { 2331 t.Fatalf("Expected INFO, got %s", l) 2332 } 2333 var info serverInfo 2334 if err := json.Unmarshal(l[5:], &info); err != nil { 2335 t.Fatalf("Unable to unmarshal info: %v", err) 2336 } 2337 // Send CONNECT and PING 2338 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n")) 2339 if _, err := wsc.Write(wsmsg); err != nil { 2340 t.Fatalf("Error sending message: %v", err) 2341 } 2342 // Wait for the PONG 2343 if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 2344 t.Fatalf("Expected PONG, got %s", msg) 2345 } 2346 2347 var uname string 2348 var accname string 2349 c := s.getClient(info.CID) 2350 if c != nil { 2351 c.mu.Lock() 2352 uname = c.opts.Username 2353 if c.acc != nil { 2354 accname = c.acc.GetName() 2355 } 2356 c.mu.Unlock() 2357 } 2358 if uname != certUserName { 2359 t.Fatalf("Expected username %q, got %q", certUserName, uname) 2360 } 2361 if accname != accName { 2362 t.Fatalf("Expected account %q, got %v", accName, accname) 2363 } 2364 }) 2365 } 2366 } 2367 2368 func TestWSHandshakeTimeout(t *testing.T) { 2369 o := testWSOptions() 2370 o.Websocket.HandshakeTimeout = time.Millisecond 2371 tc := &TLSConfigOpts{ 2372 CertFile: "./configs/certs/server.pem", 2373 KeyFile: "./configs/certs/key.pem", 2374 } 2375 o.Websocket.TLSConfig, _ = GenTLSConfig(tc) 2376 s := RunServer(o) 2377 defer s.Shutdown() 2378 2379 logger := &captureErrorLogger{errCh: make(chan string, 1)} 2380 s.SetLogger(logger, false, false) 2381 2382 addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) 2383 wsc, err := net.Dial("tcp", addr) 2384 if err != nil { 2385 t.Fatalf("Error creating ws connection: %v", err) 2386 } 2387 defer wsc.Close() 2388 2389 // Delay the handshake 2390 wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true}) 2391 time.Sleep(20 * time.Millisecond) 2392 // We expect error since the server should have cut us off 2393 if err := wsc.(*tls.Conn).Handshake(); err == nil { 2394 t.Fatal("Expected error during handshake") 2395 } 2396 2397 // Check that server logs error 2398 select { 2399 case e := <-logger.errCh: 2400 // Check that log starts with "websocket: " 2401 if !strings.HasPrefix(e, "websocket: ") { 2402 t.Fatalf("Wrong log line start: %s", e) 2403 } 2404 if !strings.Contains(e, "timeout") { 2405 t.Fatalf("Unexpected error: %v", e) 2406 } 2407 case <-time.After(time.Second): 2408 t.Fatalf("Should have timed-out") 2409 } 2410 } 2411 2412 func TestWSServerReportUpgradeFailure(t *testing.T) { 2413 o := testWSOptions() 2414 s := RunServer(o) 2415 defer s.Shutdown() 2416 2417 logger := &captureErrorLogger{errCh: make(chan string, 1)} 2418 s.SetLogger(logger, false, false) 2419 2420 addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port) 2421 req := testWSCreateValidReq() 2422 req.URL, _ = url.Parse("wss://" + addr) 2423 2424 wsc, err := net.Dial("tcp", addr) 2425 if err != nil { 2426 t.Fatalf("Error creating ws connection: %v", err) 2427 } 2428 defer wsc.Close() 2429 wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true}) 2430 if err := wsc.(*tls.Conn).Handshake(); err != nil { 2431 t.Fatalf("Error during handshake: %v", err) 2432 } 2433 // Remove a required field from the request to have it fail 2434 req.Header.Del("Connection") 2435 // Send the request 2436 if err := req.Write(wsc); err != nil { 2437 t.Fatalf("Error sending request: %v", err) 2438 } 2439 br := bufio.NewReader(wsc) 2440 resp, err := http.ReadResponse(br, req) 2441 if err != nil { 2442 t.Fatalf("Error reading response: %v", err) 2443 } 2444 defer resp.Body.Close() 2445 if resp.StatusCode != http.StatusBadRequest { 2446 t.Fatalf("Expected status %v, got %v", http.StatusBadRequest, resp.StatusCode) 2447 } 2448 2449 // Check that server logs error 2450 select { 2451 case e := <-logger.errCh: 2452 if !strings.Contains(e, "invalid value for header 'Connection'") { 2453 t.Fatalf("Unexpected error: %v", e) 2454 } 2455 // The client IP's local should be printed as a remote from server perspective. 2456 clientIP := wsc.LocalAddr().String() 2457 if !strings.HasPrefix(e, clientIP) { 2458 t.Fatalf("IP should have been logged, it was not: %v", e) 2459 } 2460 case <-time.After(time.Second): 2461 t.Fatalf("Should have timed-out") 2462 } 2463 } 2464 2465 func TestWSCloseMsgSendOnConnectionClose(t *testing.T) { 2466 o := testWSOptions() 2467 s := RunServer(o) 2468 defer s.Shutdown() 2469 2470 wsc, br := testWSCreateClient(t, false, false, o.Websocket.Host, o.Websocket.Port) 2471 defer wsc.Close() 2472 2473 checkClientsCount(t, s, 1) 2474 var c *client 2475 s.mu.Lock() 2476 for _, cli := range s.clients { 2477 c = cli 2478 break 2479 } 2480 s.mu.Unlock() 2481 2482 c.closeConnection(ProtocolViolation) 2483 msg := testWSReadFrame(t, br) 2484 if len(msg) < 2 { 2485 t.Fatalf("Should have 2 bytes to represent the status, got %v", msg) 2486 } 2487 if sc := int(binary.BigEndian.Uint16(msg[:2])); sc != wsCloseStatusProtocolError { 2488 t.Fatalf("Expected status to be %v, got %v", wsCloseStatusProtocolError, sc) 2489 } 2490 expectedPayload := ProtocolViolation.String() 2491 if p := string(msg[2:]); p != expectedPayload { 2492 t.Fatalf("Expected payload to be %q, got %q", expectedPayload, p) 2493 } 2494 } 2495 2496 func TestWSAdvertise(t *testing.T) { 2497 o := testWSOptions() 2498 o.Cluster.Port = 0 2499 o.HTTPPort = 0 2500 o.Websocket.Advertise = "xxx:host:yyy" 2501 s, err := NewServer(o) 2502 if err != nil { 2503 t.Fatalf("Unexpected error: %v", err) 2504 } 2505 defer s.Shutdown() 2506 l := &captureFatalLogger{fatalCh: make(chan string, 1)} 2507 s.SetLogger(l, false, false) 2508 s.Start() 2509 select { 2510 case e := <-l.fatalCh: 2511 if !strings.Contains(e, "Unable to get websocket connect URLs") { 2512 t.Fatalf("Unexpected error: %q", e) 2513 } 2514 case <-time.After(time.Second): 2515 t.Fatal("Should have failed to start") 2516 } 2517 s.Shutdown() 2518 2519 o1 := testWSOptions() 2520 o1.Websocket.Advertise = "host1:1234" 2521 s1 := RunServer(o1) 2522 defer s1.Shutdown() 2523 2524 wsc, br := testWSCreateClient(t, false, false, o1.Websocket.Host, o1.Websocket.Port) 2525 defer wsc.Close() 2526 2527 o2 := testWSOptions() 2528 o2.Websocket.Advertise = "host2:5678" 2529 o2.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", o1.Cluster.Host, o1.Cluster.Port)) 2530 s2 := RunServer(o2) 2531 defer s2.Shutdown() 2532 2533 checkInfo := func(expected []string) { 2534 t.Helper() 2535 infob := testWSReadFrame(t, br) 2536 info := &Info{} 2537 json.Unmarshal(infob[5:], info) 2538 if n := len(info.ClientConnectURLs); n != len(expected) { 2539 t.Fatalf("Unexpected info: %+v", info) 2540 } 2541 good := 0 2542 for _, u := range info.ClientConnectURLs { 2543 for _, eu := range expected { 2544 if u == eu { 2545 good++ 2546 } 2547 } 2548 } 2549 if good != len(expected) { 2550 t.Fatalf("Unexpected connect urls: %q", info.ClientConnectURLs) 2551 } 2552 } 2553 checkInfo([]string{"host1:1234", "host2:5678"}) 2554 2555 // Now shutdown s2 and expect another INFO 2556 s2.Shutdown() 2557 checkInfo([]string{"host1:1234"}) 2558 2559 // Restart with another advertise and check that it gets updated 2560 o2.Websocket.Advertise = "host3:9012" 2561 s2 = RunServer(o2) 2562 defer s2.Shutdown() 2563 checkInfo([]string{"host1:1234", "host3:9012"}) 2564 } 2565 2566 func TestWSFrameOutbound(t *testing.T) { 2567 for _, test := range []struct { 2568 name string 2569 maskingWrite bool 2570 }{ 2571 {"no write masking", false}, 2572 {"write masking", true}, 2573 } { 2574 t.Run(test.name, func(t *testing.T) { 2575 c, _, _ := testWSSetupForRead() 2576 c.ws.maskwrite = test.maskingWrite 2577 2578 getKey := func(buf []byte) []byte { 2579 return buf[len(buf)-4:] 2580 } 2581 2582 var bufs net.Buffers 2583 bufs = append(bufs, []byte("this ")) 2584 bufs = append(bufs, []byte("is ")) 2585 bufs = append(bufs, []byte("a ")) 2586 bufs = append(bufs, []byte("set ")) 2587 bufs = append(bufs, []byte("of ")) 2588 bufs = append(bufs, []byte("buffers")) 2589 en := 2 2590 for _, b := range bufs { 2591 en += len(b) 2592 } 2593 if test.maskingWrite { 2594 en += 4 2595 } 2596 c.mu.Lock() 2597 c.out.nb = bufs 2598 res, n := c.collapsePtoNB() 2599 c.mu.Unlock() 2600 if n != int64(en) { 2601 t.Fatalf("Expected size to be %v, got %v", en, n) 2602 } 2603 if eb := 1 + len(bufs); eb != len(res) { 2604 t.Fatalf("Expected %v buffers, got %v", eb, len(res)) 2605 } 2606 var ob []byte 2607 for i := 1; i < len(res); i++ { 2608 ob = append(ob, res[i]...) 2609 } 2610 if test.maskingWrite { 2611 wsMaskBuf(getKey(res[0]), ob) 2612 } 2613 if !bytes.Equal(ob, []byte("this is a set of buffers")) { 2614 t.Fatalf("Unexpected outbound: %q", ob) 2615 } 2616 2617 bufs = nil 2618 c.out.pb = 0 2619 c.ws.fs = 0 2620 c.ws.frames = nil 2621 c.ws.browser = true 2622 bufs = append(bufs, []byte("some smaller ")) 2623 bufs = append(bufs, []byte("buffers")) 2624 bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+10)) 2625 bufs = append(bufs, []byte("then some more")) 2626 en = 2 + len(bufs[0]) + len(bufs[1]) 2627 en += 4 + len(bufs[2]) - 10 2628 en += 2 + len(bufs[3]) + 10 2629 c.mu.Lock() 2630 c.out.nb = bufs 2631 res, n = c.collapsePtoNB() 2632 c.mu.Unlock() 2633 if test.maskingWrite { 2634 en += 3 * 4 2635 } 2636 if n != int64(en) { 2637 t.Fatalf("Expected size to be %v, got %v", en, n) 2638 } 2639 if len(res) != 8 { 2640 t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) 2641 } 2642 if len(res[4]) != wsFrameSizeForBrowsers { 2643 t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) 2644 } 2645 if len(res[6]) != 10 { 2646 t.Fatalf("Frame 6 should have the partial of 10 bytes, got %v", len(res[6])) 2647 } 2648 if test.maskingWrite { 2649 b := &bytes.Buffer{} 2650 key := getKey(res[0]) 2651 b.Write(res[1]) 2652 b.Write(res[2]) 2653 ud := b.Bytes() 2654 wsMaskBuf(key, ud) 2655 if string(ud) != "some smaller buffers" { 2656 t.Fatalf("Unexpected result: %q", ud) 2657 } 2658 2659 b.Reset() 2660 key = getKey(res[3]) 2661 b.Write(res[4]) 2662 ud = b.Bytes() 2663 wsMaskBuf(key, ud) 2664 for i := 0; i < len(ud); i++ { 2665 if ud[i] != 0 { 2666 t.Fatalf("Unexpected result: %v", ud) 2667 } 2668 } 2669 2670 b.Reset() 2671 key = getKey(res[5]) 2672 b.Write(res[6]) 2673 b.Write(res[7]) 2674 ud = b.Bytes() 2675 wsMaskBuf(key, ud) 2676 for i := 0; i < len(ud[:10]); i++ { 2677 if ud[i] != 0 { 2678 t.Fatalf("Unexpected result: %v", ud[:10]) 2679 } 2680 } 2681 if string(ud[10:]) != "then some more" { 2682 t.Fatalf("Unexpected result: %q", ud[10:]) 2683 } 2684 } 2685 2686 bufs = nil 2687 c.out.pb = 0 2688 c.ws.fs = 0 2689 c.ws.frames = nil 2690 c.ws.browser = true 2691 bufs = append(bufs, []byte("some smaller ")) 2692 bufs = append(bufs, []byte("buffers")) 2693 // Have one of the exact max size 2694 bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers)) 2695 bufs = append(bufs, []byte("then some more")) 2696 en = 2 + len(bufs[0]) + len(bufs[1]) 2697 en += 4 + len(bufs[2]) 2698 en += 2 + len(bufs[3]) 2699 c.mu.Lock() 2700 c.out.nb = bufs 2701 res, n = c.collapsePtoNB() 2702 c.mu.Unlock() 2703 if test.maskingWrite { 2704 en += 3 * 4 2705 } 2706 if n != int64(en) { 2707 t.Fatalf("Expected size to be %v, got %v", en, n) 2708 } 2709 if len(res) != 7 { 2710 t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) 2711 } 2712 if len(res[4]) != wsFrameSizeForBrowsers { 2713 t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) 2714 } 2715 if test.maskingWrite { 2716 key := getKey(res[5]) 2717 wsMaskBuf(key, res[6]) 2718 } 2719 if string(res[6]) != "then some more" { 2720 t.Fatalf("Frame 6 incorrect: %q", res[6]) 2721 } 2722 2723 bufs = nil 2724 c.out.pb = 0 2725 c.ws.fs = 0 2726 c.ws.frames = nil 2727 c.ws.browser = true 2728 bufs = append(bufs, []byte("some smaller ")) 2729 bufs = append(bufs, []byte("buffers")) 2730 // Have one of the exact max size, and last in the list 2731 bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers)) 2732 en = 2 + len(bufs[0]) + len(bufs[1]) 2733 en += 4 + len(bufs[2]) 2734 c.mu.Lock() 2735 c.out.nb = bufs 2736 res, n = c.collapsePtoNB() 2737 c.mu.Unlock() 2738 if test.maskingWrite { 2739 en += 2 * 4 2740 } 2741 if n != int64(en) { 2742 t.Fatalf("Expected size to be %v, got %v", en, n) 2743 } 2744 if len(res) != 5 { 2745 t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) 2746 } 2747 if len(res[4]) != wsFrameSizeForBrowsers { 2748 t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) 2749 } 2750 2751 bufs = nil 2752 c.out.pb = 0 2753 c.ws.fs = 0 2754 c.ws.frames = nil 2755 c.ws.browser = true 2756 bufs = append(bufs, []byte("some smaller buffer")) 2757 bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers-5)) 2758 bufs = append(bufs, []byte("then some more")) 2759 en = 2 + len(bufs[0]) 2760 en += 4 + len(bufs[1]) 2761 en += 2 + len(bufs[2]) 2762 c.mu.Lock() 2763 c.out.nb = bufs 2764 res, n = c.collapsePtoNB() 2765 c.mu.Unlock() 2766 if test.maskingWrite { 2767 en += 3 * 4 2768 } 2769 if n != int64(en) { 2770 t.Fatalf("Expected size to be %v, got %v", en, n) 2771 } 2772 if len(res) != 6 { 2773 t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) 2774 } 2775 if len(res[3]) != wsFrameSizeForBrowsers-5 { 2776 t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) 2777 } 2778 if test.maskingWrite { 2779 key := getKey(res[4]) 2780 wsMaskBuf(key, res[5]) 2781 } 2782 if string(res[5]) != "then some more" { 2783 t.Fatalf("Frame 6 incorrect %q", res[5]) 2784 } 2785 2786 bufs = nil 2787 c.out.pb = 0 2788 c.ws.fs = 0 2789 c.ws.frames = nil 2790 c.ws.browser = true 2791 bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+100)) 2792 c.mu.Lock() 2793 c.out.nb = bufs 2794 res, _ = c.collapsePtoNB() 2795 c.mu.Unlock() 2796 if len(res) != 4 { 2797 t.Fatalf("Unexpected number of frames: %v", len(res)) 2798 } 2799 }) 2800 } 2801 } 2802 2803 func TestWSWebrowserClient(t *testing.T) { 2804 o := testWSOptions() 2805 s := RunServer(o) 2806 defer s.Shutdown() 2807 2808 wsc, br := testWSCreateClient(t, false, true, o.Websocket.Host, o.Websocket.Port) 2809 defer wsc.Close() 2810 2811 checkClientsCount(t, s, 1) 2812 var c *client 2813 s.mu.Lock() 2814 for _, cli := range s.clients { 2815 c = cli 2816 break 2817 } 2818 s.mu.Unlock() 2819 2820 proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB foo 1\r\nPING\r\n")) 2821 wsc.Write(proto) 2822 if res := testWSReadFrame(t, br); !bytes.Equal(res, []byte(pongProto)) { 2823 t.Fatalf("Expected PONG back") 2824 } 2825 2826 c.mu.Lock() 2827 ok := c.isWebsocket() && c.ws.browser == true 2828 c.mu.Unlock() 2829 if !ok { 2830 t.Fatalf("Client is not marked as webrowser client") 2831 } 2832 2833 nc := natsConnect(t, s.ClientURL()) 2834 defer nc.Close() 2835 2836 // Send a big message and check that it is received in smaller frames 2837 psize := 204813 2838 nc.Publish("foo", make([]byte, psize)) 2839 nc.Flush() 2840 2841 rsize := psize + len(fmt.Sprintf("MSG foo %d\r\n\r\n", psize)) 2842 nframes := 0 2843 for total := 0; total < rsize; nframes++ { 2844 res := testWSReadFrame(t, br) 2845 total += len(res) 2846 } 2847 if expected := psize / wsFrameSizeForBrowsers; expected > nframes { 2848 t.Fatalf("Expected %v frames, got %v", expected, nframes) 2849 } 2850 } 2851 2852 type testWSWrappedConn struct { 2853 net.Conn 2854 mu sync.RWMutex 2855 buf *bytes.Buffer 2856 partial bool 2857 } 2858 2859 func (wc *testWSWrappedConn) Write(p []byte) (int, error) { 2860 wc.mu.Lock() 2861 defer wc.mu.Unlock() 2862 var err error 2863 n := len(p) 2864 if wc.partial && n > 10 { 2865 n = 10 2866 err = io.ErrShortWrite 2867 } 2868 p = p[:n] 2869 wc.buf.Write(p) 2870 wc.Conn.Write(p) 2871 return n, err 2872 } 2873 2874 func TestWSCompressionBasic(t *testing.T) { 2875 payload := "This is the content of a message that will be compresseddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd." 2876 msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload) 2877 cbuf := &bytes.Buffer{} 2878 compressor, err := flate.NewWriter(cbuf, flate.BestSpeed) 2879 require_NoError(t, err) 2880 compressor.Write([]byte(msgProto)) 2881 compressor.Flush() 2882 compressed := cbuf.Bytes() 2883 // The last 4 bytes are dropped 2884 compressed = compressed[:len(compressed)-4] 2885 2886 o := testWSOptions() 2887 o.Websocket.Compression = true 2888 s := RunServer(o) 2889 defer s.Shutdown() 2890 2891 c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port) 2892 defer c.Close() 2893 2894 proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n")) 2895 c.Write(proto) 2896 l := testWSReadFrame(t, br) 2897 if !bytes.Equal(l, []byte(pongProto)) { 2898 t.Fatalf("Expected PONG, got %q", l) 2899 } 2900 2901 var wc *testWSWrappedConn 2902 s.mu.RLock() 2903 for _, c := range s.clients { 2904 c.mu.Lock() 2905 wc = &testWSWrappedConn{Conn: c.nc, buf: &bytes.Buffer{}} 2906 c.nc = wc 2907 c.mu.Unlock() 2908 } 2909 s.mu.RUnlock() 2910 2911 nc := natsConnect(t, s.ClientURL()) 2912 defer nc.Close() 2913 natsPub(t, nc, "foo", []byte(payload)) 2914 2915 res := &bytes.Buffer{} 2916 for total := 0; total < len(msgProto); { 2917 l := testWSReadFrame(t, br) 2918 n, _ := res.Write(l) 2919 total += n 2920 } 2921 if !bytes.Equal([]byte(msgProto), res.Bytes()) { 2922 t.Fatalf("Unexpected result: %q", res) 2923 } 2924 2925 // Now check the wrapped connection buffer to check that data was actually compressed. 2926 wc.mu.RLock() 2927 res = wc.buf 2928 wc.mu.RUnlock() 2929 if bytes.Contains(res.Bytes(), []byte(payload)) { 2930 t.Fatalf("Looks like frame was not compressed: %q", res.Bytes()) 2931 } 2932 header := res.Bytes()[:2] 2933 body := res.Bytes()[2:] 2934 expectedB0 := byte(wsBinaryMessage) | wsFinalBit | wsRsv1Bit 2935 expectedPS := len(compressed) 2936 expectedB1 := byte(expectedPS) 2937 2938 if b := header[0]; b != expectedB0 { 2939 t.Fatalf("Expected first byte to be %v, got %v", expectedB0, b) 2940 } 2941 if b := header[1]; b != expectedB1 { 2942 t.Fatalf("Expected second byte to be %v, got %v", expectedB1, b) 2943 } 2944 if len(body) != expectedPS { 2945 t.Fatalf("Expected payload length to be %v, got %v", expectedPS, len(body)) 2946 } 2947 if !bytes.Equal(body, compressed) { 2948 t.Fatalf("Unexpected compress body: %q", body) 2949 } 2950 2951 wc.mu.Lock() 2952 wc.buf.Reset() 2953 wc.mu.Unlock() 2954 2955 payload = "small" 2956 natsPub(t, nc, "foo", []byte(payload)) 2957 msgProto = fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload) 2958 res = &bytes.Buffer{} 2959 for total := 0; total < len(msgProto); { 2960 l := testWSReadFrame(t, br) 2961 n, _ := res.Write(l) 2962 total += n 2963 } 2964 if !bytes.Equal([]byte(msgProto), res.Bytes()) { 2965 t.Fatalf("Unexpected result: %q", res) 2966 } 2967 wc.mu.RLock() 2968 res = wc.buf 2969 wc.mu.RUnlock() 2970 if !bytes.HasSuffix(res.Bytes(), []byte(msgProto)) { 2971 t.Fatalf("Looks like frame was compressed: %q", res.Bytes()) 2972 } 2973 } 2974 2975 func TestWSCompressionWithPartialWrite(t *testing.T) { 2976 payload := "This is the content of a message that will be compresseddddddddddddddddddddd." 2977 msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload) 2978 2979 o := testWSOptions() 2980 o.Websocket.Compression = true 2981 s := RunServer(o) 2982 defer s.Shutdown() 2983 2984 c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port) 2985 defer c.Close() 2986 2987 proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n")) 2988 c.Write(proto) 2989 l := testWSReadFrame(t, br) 2990 if !bytes.Equal(l, []byte(pongProto)) { 2991 t.Fatalf("Expected PONG, got %q", l) 2992 } 2993 2994 pingPayload := []byte("my ping") 2995 pingFromWSClient := testWSCreateClientMsg(wsPingMessage, 1, true, false, pingPayload) 2996 2997 var wc *testWSWrappedConn 2998 var ws *client 2999 s.mu.Lock() 3000 for _, c := range s.clients { 3001 ws = c 3002 c.mu.Lock() 3003 wc = &testWSWrappedConn{ 3004 Conn: c.nc, 3005 buf: &bytes.Buffer{}, 3006 } 3007 c.nc = wc 3008 c.mu.Unlock() 3009 break 3010 } 3011 s.mu.Unlock() 3012 3013 wc.mu.Lock() 3014 wc.partial = true 3015 wc.mu.Unlock() 3016 3017 nc := natsConnect(t, s.ClientURL()) 3018 defer nc.Close() 3019 3020 expected := &bytes.Buffer{} 3021 for i := 0; i < 10; i++ { 3022 if i > 0 { 3023 time.Sleep(10 * time.Millisecond) 3024 } 3025 expected.Write([]byte(msgProto)) 3026 natsPub(t, nc, "foo", []byte(payload)) 3027 if i == 1 { 3028 c.Write(pingFromWSClient) 3029 } 3030 } 3031 3032 var gotPingResponse bool 3033 res := &bytes.Buffer{} 3034 for total := 0; total < 10*len(msgProto); { 3035 l := testWSReadFrame(t, br) 3036 if bytes.Equal(l, pingPayload) { 3037 gotPingResponse = true 3038 } else { 3039 n, _ := res.Write(l) 3040 total += n 3041 } 3042 } 3043 if !bytes.Equal(expected.Bytes(), res.Bytes()) { 3044 t.Fatalf("Unexpected result: %q", res) 3045 } 3046 if !gotPingResponse { 3047 t.Fatal("Did not get the ping response") 3048 } 3049 3050 checkFor(t, time.Second, 15*time.Millisecond, func() error { 3051 ws.mu.Lock() 3052 pb := ws.out.pb 3053 wf := ws.ws.frames 3054 fs := ws.ws.fs 3055 ws.mu.Unlock() 3056 if pb != 0 || len(wf) != 0 || fs != 0 { 3057 return fmt.Errorf("Expected pb, wf and fs to be 0, got %v, %v, %v", pb, wf, fs) 3058 } 3059 return nil 3060 }) 3061 } 3062 3063 func TestWSCompressionFrameSizeLimit(t *testing.T) { 3064 for _, test := range []struct { 3065 name string 3066 maskWrite bool 3067 noLimit bool 3068 }{ 3069 {"no write masking", false, false}, 3070 {"write masking", true, false}, 3071 } { 3072 t.Run(test.name, func(t *testing.T) { 3073 opts := testWSOptions() 3074 opts.MaxPending = MAX_PENDING_SIZE 3075 s := &Server{opts: opts} 3076 c := &client{srv: s, ws: &websocket{compress: true, browser: true, nocompfrag: test.noLimit, maskwrite: test.maskWrite}} 3077 c.initClient() 3078 3079 uncompressedPayload := make([]byte, 2*wsFrameSizeForBrowsers) 3080 for i := 0; i < len(uncompressedPayload); i++ { 3081 uncompressedPayload[i] = byte(rand.Intn(256)) 3082 } 3083 3084 c.mu.Lock() 3085 c.out.nb = append(net.Buffers(nil), uncompressedPayload) 3086 nb, _ := c.collapsePtoNB() 3087 c.mu.Unlock() 3088 3089 if test.noLimit && len(nb) != 2 { 3090 t.Fatalf("There should be only 2 buffers, the header and payload, got %v", len(nb)) 3091 } 3092 3093 bb := &bytes.Buffer{} 3094 var key []byte 3095 for i, b := range nb { 3096 if !test.noLimit { 3097 // frame header buffer are always very small. The payload should not be more 3098 // than 10 bytes since that is what we passed as the limit. 3099 if len(b) > wsFrameSizeForBrowsers { 3100 t.Fatalf("Frame size too big: %v (%q)", len(b), b) 3101 } 3102 } 3103 if test.maskWrite { 3104 if i%2 == 0 { 3105 key = b[len(b)-4:] 3106 } else { 3107 wsMaskBuf(key, b) 3108 } 3109 } 3110 // Check frame headers for the proper formatting. 3111 if i%2 == 0 { 3112 // Only the first frame should have the compress bit set. 3113 if b[0]&wsRsv1Bit != 0 { 3114 if i > 0 { 3115 t.Fatalf("Compressed bit should not be in continuation frame") 3116 } 3117 } else if i == 0 { 3118 t.Fatalf("Compressed bit missing") 3119 } 3120 } else { 3121 if test.noLimit { 3122 // Since the payload is likely not well compressed, we are expecting 3123 // the length to be > wsFrameSizeForBrowsers 3124 if len(b) <= wsFrameSizeForBrowsers { 3125 t.Fatalf("Expected frame to be bigger, got %v", len(b)) 3126 } 3127 } 3128 // Collect the payload 3129 bb.Write(b) 3130 } 3131 } 3132 buf := bb.Bytes() 3133 buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff) 3134 dbr := bytes.NewBuffer(buf) 3135 d := flate.NewReader(dbr) 3136 uncompressed, err := io.ReadAll(d) 3137 if err != nil { 3138 t.Fatalf("Error reading frame: %v", err) 3139 } 3140 if !bytes.Equal(uncompressed, uncompressedPayload) { 3141 t.Fatalf("Unexpected uncomressed data: %q", uncompressed) 3142 } 3143 }) 3144 } 3145 } 3146 3147 func TestWSBasicAuth(t *testing.T) { 3148 for _, test := range []struct { 3149 name string 3150 opts func() *Options 3151 user string 3152 pass string 3153 err string 3154 cookies []string 3155 }{ 3156 { 3157 "top level auth, no override, wrong u/p", 3158 func() *Options { 3159 o := testWSOptions() 3160 o.Username = "normal" 3161 o.Password = "client" 3162 return o 3163 }, 3164 "websocket", "client", "-ERR 'Authorization Violation'", 3165 nil, 3166 }, 3167 { 3168 "top level auth, no override, correct u/p", 3169 func() *Options { 3170 o := testWSOptions() 3171 o.Username = "normal" 3172 o.Password = "client" 3173 return o 3174 }, 3175 "normal", "client", "", 3176 nil, 3177 }, 3178 { 3179 "no top level auth, ws auth, wrong u/p", 3180 func() *Options { 3181 o := testWSOptions() 3182 o.Websocket.Username = "websocket" 3183 o.Websocket.Password = "client" 3184 return o 3185 }, 3186 "normal", "client", "-ERR 'Authorization Violation'", 3187 nil, 3188 }, 3189 { 3190 "no top level auth, ws auth, correct u/p", 3191 func() *Options { 3192 o := testWSOptions() 3193 o.Websocket.Username = "websocket" 3194 o.Websocket.Password = "client" 3195 return o 3196 }, 3197 "websocket", "client", "", 3198 nil, 3199 }, 3200 { 3201 "top level auth, ws override, wrong u/p", 3202 func() *Options { 3203 o := testWSOptions() 3204 o.Username = "normal" 3205 o.Password = "client" 3206 o.Websocket.Username = "websocket" 3207 o.Websocket.Password = "client" 3208 return o 3209 }, 3210 "normal", "client", "-ERR 'Authorization Violation'", 3211 nil, 3212 }, 3213 { 3214 "top level auth, ws override, correct u/p", 3215 func() *Options { 3216 o := testWSOptions() 3217 o.Username = "normal" 3218 o.Password = "client" 3219 o.Websocket.Username = "websocket" 3220 o.Websocket.Password = "client" 3221 return o 3222 }, 3223 "websocket", "client", "", 3224 nil, 3225 }, 3226 { 3227 "username/password from cookies", 3228 func() *Options { 3229 o := testWSOptions() 3230 o.Websocket.UsernameCookie = "un" 3231 o.Websocket.PasswordCookie = "pw" 3232 o.Username = "me" 3233 o.Password = "s3cr3t!" 3234 return o 3235 }, 3236 "", "", "", 3237 []string{"un=me", "pw=s3cr3t!"}, 3238 }, 3239 { 3240 "bad username/ good password from cookies", 3241 func() *Options { 3242 o := testWSOptions() 3243 o.Websocket.UsernameCookie = "un" 3244 o.Websocket.PasswordCookie = "pw" 3245 o.Username = "me" 3246 o.Password = "s3cr3t!" 3247 return o 3248 }, 3249 "", "", "-ERR 'Authorization Violation", 3250 []string{"un=m", "pw=s3cr3t!"}, 3251 }, 3252 { 3253 "good username/ bad password from cookies", 3254 func() *Options { 3255 o := testWSOptions() 3256 o.Websocket.UsernameCookie = "un" 3257 o.Websocket.PasswordCookie = "pw" 3258 o.Username = "me" 3259 o.Password = "s3cr3t!" 3260 return o 3261 }, 3262 "", "", "-ERR 'Authorization Violation", 3263 []string{"un=me", "pw=hi!"}, 3264 }, 3265 { 3266 "token from cookie", 3267 func() *Options { 3268 o := testWSOptions() 3269 o.Websocket.TokenCookie = "tok" 3270 o.Authorization = "l3tm31n!" 3271 return o 3272 }, 3273 "", "", "", 3274 []string{"tok=l3tm31n!"}, 3275 }, 3276 { 3277 "bad token from cookie", 3278 func() *Options { 3279 o := testWSOptions() 3280 o.Websocket.TokenCookie = "tok" 3281 o.Authorization = "l3tm31n!" 3282 return o 3283 }, 3284 "", "", "-ERR 'Authorization Violation", 3285 []string{"tok=hello!"}, 3286 }, 3287 } { 3288 t.Run(test.name, func(t *testing.T) { 3289 o := test.opts() 3290 s := RunServer(o) 3291 defer s.Shutdown() 3292 3293 wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port, test.cookies...) 3294 defer wsc.Close() 3295 3296 connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", 3297 test.user, test.pass) 3298 3299 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3300 if _, err := wsc.Write(wsmsg); err != nil { 3301 t.Fatalf("Error sending message: %v", err) 3302 } 3303 msg := testWSReadFrame(t, br) 3304 if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3305 t.Fatalf("Expected to receive PONG, got %q", msg) 3306 } else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { 3307 t.Fatalf("Expected to receive %q, got %q", test.err, msg) 3308 } 3309 }) 3310 } 3311 } 3312 3313 func TestWSAuthTimeout(t *testing.T) { 3314 for _, test := range []struct { 3315 name string 3316 at float64 3317 wat float64 3318 err string 3319 }{ 3320 {"use top-level auth timeout", 10.0, 0.0, ""}, 3321 {"use websocket auth timeout", 10.0, 0.05, "-ERR 'Authentication Timeout'"}, 3322 } { 3323 t.Run(test.name, func(t *testing.T) { 3324 o := testWSOptions() 3325 o.AuthTimeout = test.at 3326 o.Websocket.Username = "websocket" 3327 o.Websocket.Password = "client" 3328 o.Websocket.AuthTimeout = test.wat 3329 s := RunServer(o) 3330 defer s.Shutdown() 3331 3332 wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) 3333 defer wsc.Close() 3334 3335 var info serverInfo 3336 json.Unmarshal([]byte(l[5:]), &info) 3337 // Make sure that we are told that auth is required. 3338 if !info.AuthRequired { 3339 t.Fatalf("Expected auth required, was not: %q", l) 3340 } 3341 start := time.Now() 3342 // Wait before sending connect 3343 time.Sleep(100 * time.Millisecond) 3344 connectProto := "CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"websocket\",\"pass\":\"client\"}\r\nPING\r\n" 3345 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3346 if _, err := wsc.Write(wsmsg); err != nil { 3347 t.Fatalf("Error sending message: %v", err) 3348 } 3349 msg := testWSReadFrame(t, br) 3350 if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { 3351 t.Fatalf("Expected to receive %q error, got %q", test.err, msg) 3352 } else if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3353 t.Fatalf("Unexpected error: %q", msg) 3354 } 3355 if dur := time.Since(start); dur > time.Second { 3356 t.Fatalf("Too long to get timeout error: %v", dur) 3357 } 3358 }) 3359 } 3360 } 3361 3362 func TestWSTokenAuth(t *testing.T) { 3363 for _, test := range []struct { 3364 name string 3365 opts func() *Options 3366 token string 3367 err string 3368 }{ 3369 { 3370 "top level auth, no override, wrong token", 3371 func() *Options { 3372 o := testWSOptions() 3373 o.Authorization = "goodtoken" 3374 return o 3375 }, 3376 "badtoken", "-ERR 'Authorization Violation'", 3377 }, 3378 { 3379 "top level auth, no override, correct token", 3380 func() *Options { 3381 o := testWSOptions() 3382 o.Authorization = "goodtoken" 3383 return o 3384 }, 3385 "goodtoken", "", 3386 }, 3387 { 3388 "no top level auth, ws auth, wrong token", 3389 func() *Options { 3390 o := testWSOptions() 3391 o.Websocket.Token = "goodtoken" 3392 return o 3393 }, 3394 "badtoken", "-ERR 'Authorization Violation'", 3395 }, 3396 { 3397 "no top level auth, ws auth, correct token", 3398 func() *Options { 3399 o := testWSOptions() 3400 o.Websocket.Token = "goodtoken" 3401 return o 3402 }, 3403 "goodtoken", "", 3404 }, 3405 { 3406 "top level auth, ws override, wrong token", 3407 func() *Options { 3408 o := testWSOptions() 3409 o.Authorization = "clienttoken" 3410 o.Websocket.Token = "websockettoken" 3411 return o 3412 }, 3413 "clienttoken", "-ERR 'Authorization Violation'", 3414 }, 3415 { 3416 "top level auth, ws override, correct token", 3417 func() *Options { 3418 o := testWSOptions() 3419 o.Authorization = "clienttoken" 3420 o.Websocket.Token = "websockettoken" 3421 return o 3422 }, 3423 "websockettoken", "", 3424 }, 3425 } { 3426 t.Run(test.name, func(t *testing.T) { 3427 o := test.opts() 3428 s := RunServer(o) 3429 defer s.Shutdown() 3430 3431 wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) 3432 defer wsc.Close() 3433 3434 connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"auth_token\":\"%s\"}\r\nPING\r\n", 3435 test.token) 3436 3437 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3438 if _, err := wsc.Write(wsmsg); err != nil { 3439 t.Fatalf("Error sending message: %v", err) 3440 } 3441 msg := testWSReadFrame(t, br) 3442 if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3443 t.Fatalf("Expected to receive PONG, got %q", msg) 3444 } else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { 3445 t.Fatalf("Expected to receive %q, got %q", test.err, msg) 3446 } 3447 }) 3448 } 3449 } 3450 3451 func TestWSBindToProperAccount(t *testing.T) { 3452 conf := createConfFile(t, []byte(fmt.Sprintf(` 3453 listen: "127.0.0.1:-1" 3454 accounts { 3455 a { 3456 users [ 3457 {user: a, password: pwd, allowed_connection_types: ["%s", "%s"]} 3458 ] 3459 } 3460 b { 3461 users [ 3462 {user: b, password: pwd} 3463 ] 3464 } 3465 } 3466 websocket { 3467 listen: "127.0.0.1:-1" 3468 no_tls: true 3469 } 3470 `, jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)))) // on purpose use lower case to ensure that it is converted. 3471 s, o := RunServerWithConfig(conf) 3472 defer s.Shutdown() 3473 3474 nc := natsConnect(t, fmt.Sprintf("nats://a:pwd@127.0.0.1:%d", o.Port)) 3475 defer nc.Close() 3476 3477 sub := natsSubSync(t, nc, "foo") 3478 3479 wsc, br, _ := testNewWSClient(t, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port, noTLS: true}) 3480 // Send CONNECT and PING 3481 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, 3482 []byte(fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", "a", "pwd"))) 3483 if _, err := wsc.Write(wsmsg); err != nil { 3484 t.Fatalf("Error sending message: %v", err) 3485 } 3486 // Wait for the PONG 3487 if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3488 t.Fatalf("Expected PONG, got %s", msg) 3489 } 3490 3491 wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n")) 3492 if _, err := wsc.Write(wsmsg); err != nil { 3493 t.Fatalf("Error sending message: %v", err) 3494 } 3495 3496 natsNexMsg(t, sub, time.Second) 3497 } 3498 3499 func TestWSUsersAuth(t *testing.T) { 3500 users := []*User{{Username: "user", Password: "pwd"}} 3501 for _, test := range []struct { 3502 name string 3503 opts func() *Options 3504 user string 3505 pass string 3506 err string 3507 }{ 3508 { 3509 "no filtering, wrong user", 3510 func() *Options { 3511 o := testWSOptions() 3512 o.Users = users 3513 return o 3514 }, 3515 "wronguser", "pwd", "-ERR 'Authorization Violation'", 3516 }, 3517 { 3518 "no filtering, correct user", 3519 func() *Options { 3520 o := testWSOptions() 3521 o.Users = users 3522 return o 3523 }, 3524 "user", "pwd", "", 3525 }, 3526 { 3527 "filering, user not allowed", 3528 func() *Options { 3529 o := testWSOptions() 3530 o.Users = users 3531 // Only allowed for regular clients 3532 o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard}) 3533 return o 3534 }, 3535 "user", "pwd", "-ERR 'Authorization Violation'", 3536 }, 3537 { 3538 "filtering, user allowed", 3539 func() *Options { 3540 o := testWSOptions() 3541 o.Users = users 3542 o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}) 3543 return o 3544 }, 3545 "user", "pwd", "", 3546 }, 3547 { 3548 "filtering, wrong password", 3549 func() *Options { 3550 o := testWSOptions() 3551 o.Users = users 3552 o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}) 3553 return o 3554 }, 3555 "user", "badpassword", "-ERR 'Authorization Violation'", 3556 }, 3557 } { 3558 t.Run(test.name, func(t *testing.T) { 3559 o := test.opts() 3560 s := RunServer(o) 3561 defer s.Shutdown() 3562 3563 wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) 3564 defer wsc.Close() 3565 3566 connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", 3567 test.user, test.pass) 3568 3569 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3570 if _, err := wsc.Write(wsmsg); err != nil { 3571 t.Fatalf("Error sending message: %v", err) 3572 } 3573 msg := testWSReadFrame(t, br) 3574 if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3575 t.Fatalf("Expected to receive PONG, got %q", msg) 3576 } else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { 3577 t.Fatalf("Expected to receive %q, got %q", test.err, msg) 3578 } 3579 }) 3580 } 3581 } 3582 3583 func TestWSNoAuthUserValidation(t *testing.T) { 3584 o := testWSOptions() 3585 o.Users = []*User{{Username: "user", Password: "pwd"}} 3586 // Should fail because it is not part of o.Users. 3587 o.Websocket.NoAuthUser = "notfound" 3588 if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") { 3589 t.Fatalf("Expected error saying not present as user, got %v", err) 3590 } 3591 // Set a valid no auth user for global options, but still should fail because 3592 // of o.Websocket.NoAuthUser 3593 o.NoAuthUser = "user" 3594 o.Websocket.NoAuthUser = "notfound" 3595 if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") { 3596 t.Fatalf("Expected error saying not present as user, got %v", err) 3597 } 3598 } 3599 3600 func TestWSNoAuthUser(t *testing.T) { 3601 for _, test := range []struct { 3602 name string 3603 override bool 3604 useAuth bool 3605 expectedUser string 3606 expectedAcc string 3607 }{ 3608 {"no override, no user provided", false, false, "noauth", "normal"}, 3609 {"no override, user povided", false, true, "user", "normal"}, 3610 {"override, no user provided", true, false, "wsnoauth", "websocket"}, 3611 {"override, user provided", true, true, "wsuser", "websocket"}, 3612 } { 3613 t.Run(test.name, func(t *testing.T) { 3614 o := testWSOptions() 3615 normalAcc := NewAccount("normal") 3616 websocketAcc := NewAccount("websocket") 3617 o.Accounts = []*Account{normalAcc, websocketAcc} 3618 o.Users = []*User{ 3619 {Username: "noauth", Password: "pwd", Account: normalAcc}, 3620 {Username: "user", Password: "pwd", Account: normalAcc}, 3621 {Username: "wsnoauth", Password: "pwd", Account: websocketAcc}, 3622 {Username: "wsuser", Password: "pwd", Account: websocketAcc}, 3623 } 3624 o.NoAuthUser = "noauth" 3625 if test.override { 3626 o.Websocket.NoAuthUser = "wsnoauth" 3627 } 3628 s := RunServer(o) 3629 defer s.Shutdown() 3630 3631 wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) 3632 defer wsc.Close() 3633 3634 var info serverInfo 3635 json.Unmarshal([]byte(l[5:]), &info) 3636 3637 var connectProto string 3638 if test.useAuth { 3639 connectProto = fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"pwd\"}\r\nPING\r\n", 3640 test.expectedUser) 3641 } else { 3642 connectProto = "CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n" 3643 } 3644 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3645 if _, err := wsc.Write(wsmsg); err != nil { 3646 t.Fatalf("Error sending message: %v", err) 3647 } 3648 msg := testWSReadFrame(t, br) 3649 if !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3650 t.Fatalf("Unexpected error: %q", msg) 3651 } 3652 3653 c := s.getClient(info.CID) 3654 c.mu.Lock() 3655 uname := c.opts.Username 3656 aname := c.acc.GetName() 3657 c.mu.Unlock() 3658 if uname != test.expectedUser { 3659 t.Fatalf("Expected selected user to be %q, got %q", test.expectedUser, uname) 3660 } 3661 if aname != test.expectedAcc { 3662 t.Fatalf("Expected selected account to be %q, got %q", test.expectedAcc, aname) 3663 } 3664 }) 3665 } 3666 } 3667 3668 func TestWSNkeyAuth(t *testing.T) { 3669 nkp, _ := nkeys.CreateUser() 3670 pub, _ := nkp.PublicKey() 3671 3672 wsnkp, _ := nkeys.CreateUser() 3673 wspub, _ := wsnkp.PublicKey() 3674 3675 badkp, _ := nkeys.CreateUser() 3676 badpub, _ := badkp.PublicKey() 3677 3678 for _, test := range []struct { 3679 name string 3680 opts func() *Options 3681 nkey string 3682 kp nkeys.KeyPair 3683 err string 3684 }{ 3685 { 3686 "no filtering, wrong nkey", 3687 func() *Options { 3688 o := testWSOptions() 3689 o.Nkeys = []*NkeyUser{{Nkey: pub}} 3690 return o 3691 }, 3692 badpub, badkp, "-ERR 'Authorization Violation'", 3693 }, 3694 { 3695 "no filtering, correct nkey", 3696 func() *Options { 3697 o := testWSOptions() 3698 o.Nkeys = []*NkeyUser{{Nkey: pub}} 3699 return o 3700 }, 3701 pub, nkp, "", 3702 }, 3703 { 3704 "filtering, nkey not allowed", 3705 func() *Options { 3706 o := testWSOptions() 3707 o.Nkeys = []*NkeyUser{ 3708 { 3709 Nkey: pub, 3710 AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard}), 3711 }, 3712 { 3713 Nkey: wspub, 3714 AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeWebsocket}), 3715 }, 3716 } 3717 return o 3718 }, 3719 pub, nkp, "-ERR 'Authorization Violation'", 3720 }, 3721 { 3722 "filtering, correct nkey", 3723 func() *Options { 3724 o := testWSOptions() 3725 o.Nkeys = []*NkeyUser{ 3726 {Nkey: pub}, 3727 { 3728 Nkey: wspub, 3729 AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}), 3730 }, 3731 } 3732 return o 3733 }, 3734 wspub, wsnkp, "", 3735 }, 3736 { 3737 "filtering, wrong nkey", 3738 func() *Options { 3739 o := testWSOptions() 3740 o.Nkeys = []*NkeyUser{ 3741 { 3742 Nkey: wspub, 3743 AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}), 3744 }, 3745 } 3746 return o 3747 }, 3748 badpub, badkp, "-ERR 'Authorization Violation'", 3749 }, 3750 } { 3751 t.Run(test.name, func(t *testing.T) { 3752 o := test.opts() 3753 s := RunServer(o) 3754 defer s.Shutdown() 3755 3756 wsc, br, infoMsg := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) 3757 defer wsc.Close() 3758 3759 // Sign Nonce 3760 var info nonceInfo 3761 json.Unmarshal([]byte(infoMsg[5:]), &info) 3762 sigraw, _ := test.kp.Sign([]byte(info.Nonce)) 3763 sig := base64.RawURLEncoding.EncodeToString(sigraw) 3764 3765 connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"nkey\":\"%s\",\"sig\":\"%s\"}\r\nPING\r\n", test.nkey, sig) 3766 3767 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3768 if _, err := wsc.Write(wsmsg); err != nil { 3769 t.Fatalf("Error sending message: %v", err) 3770 } 3771 msg := testWSReadFrame(t, br) 3772 if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3773 t.Fatalf("Expected to receive PONG, got %q", msg) 3774 } else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { 3775 t.Fatalf("Expected to receive %q, got %q", test.err, msg) 3776 } 3777 }) 3778 } 3779 } 3780 3781 func TestWSJWTWithAllowedConnectionTypes(t *testing.T) { 3782 o := testWSOptions() 3783 setupAddTrusted(o) 3784 s := RunServer(o) 3785 buildMemAccResolver(s) 3786 defer s.Shutdown() 3787 3788 for _, test := range []struct { 3789 name string 3790 connectionTypes []string 3791 expectedAnswer string 3792 }{ 3793 {"not allowed", []string{jwt.ConnectionTypeStandard}, "-ERR"}, 3794 {"allowed", []string{jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)}, "+OK"}, 3795 {"allowed with unknown", []string{jwt.ConnectionTypeWebsocket, "SomeNewType"}, "+OK"}, 3796 {"not allowed with unknown", []string{"SomeNewType"}, "-ERR"}, 3797 } { 3798 t.Run(test.name, func(t *testing.T) { 3799 nuc := newJWTTestUserClaims() 3800 nuc.AllowedConnectionTypes = test.connectionTypes 3801 claimOpt := testClaimsOptions{ 3802 nuc: nuc, 3803 expectAnswer: test.expectedAnswer, 3804 } 3805 _, c, _, _ := testWSWithClaims(t, s, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port}, claimOpt) 3806 c.Close() 3807 }) 3808 } 3809 } 3810 3811 func TestWSJWTCookieUser(t *testing.T) { 3812 nucSigFunc := func() *jwt.UserClaims { return newJWTTestUserClaims() } 3813 nucBearerFunc := func() *jwt.UserClaims { 3814 ret := newJWTTestUserClaims() 3815 ret.BearerToken = true 3816 return ret 3817 } 3818 3819 o := testWSOptions() 3820 setupAddTrusted(o) 3821 setupAddCookie(o) 3822 s := RunServer(o) 3823 buildMemAccResolver(s) 3824 defer s.Shutdown() 3825 3826 genJwt := func(t *testing.T, nuc *jwt.UserClaims) string { 3827 okp, _ := nkeys.FromSeed(oSeed) 3828 3829 akp, _ := nkeys.CreateAccount() 3830 apub, _ := akp.PublicKey() 3831 3832 nac := jwt.NewAccountClaims(apub) 3833 ajwt, err := nac.Encode(okp) 3834 if err != nil { 3835 t.Fatalf("Error generating account JWT: %v", err) 3836 } 3837 3838 nkp, _ := nkeys.CreateUser() 3839 pub, _ := nkp.PublicKey() 3840 nuc.Subject = pub 3841 jwt, err := nuc.Encode(akp) 3842 if err != nil { 3843 t.Fatalf("Error generating user JWT: %v", err) 3844 } 3845 addAccountToMemResolver(s, apub, ajwt) 3846 return jwt 3847 } 3848 3849 cliOpts := testWSClientOptions{ 3850 host: o.Websocket.Host, 3851 port: o.Websocket.Port, 3852 } 3853 for _, test := range []struct { 3854 name string 3855 nuc *jwt.UserClaims 3856 opts func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) 3857 expectAnswer string 3858 }{ 3859 { 3860 name: "protocol auth, non-bearer key, with signature", 3861 nuc: nucSigFunc(), 3862 opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 3863 return cliOpts, testClaimsOptions{nuc: claims} 3864 }, 3865 expectAnswer: "+OK", 3866 }, 3867 { 3868 name: "protocol auth, non-bearer key, w/o required signature", 3869 nuc: nucSigFunc(), 3870 opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 3871 return cliOpts, testClaimsOptions{nuc: claims, dontSign: true} 3872 }, 3873 expectAnswer: "-ERR", 3874 }, 3875 { 3876 name: "protocol auth, bearer key, w/o signature", 3877 nuc: nucBearerFunc(), 3878 opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 3879 return cliOpts, testClaimsOptions{nuc: claims, dontSign: true} 3880 }, 3881 expectAnswer: "+OK", 3882 }, 3883 { 3884 name: "cookie auth, non-bearer key, protocol auth fail", 3885 nuc: nucSigFunc(), 3886 opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 3887 co := cliOpts 3888 co.extraHeaders = map[string][]string{} 3889 co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)} 3890 return co, testClaimsOptions{connectRequest: struct{}{}} 3891 }, 3892 expectAnswer: "-ERR", 3893 }, 3894 { 3895 name: "cookie auth, bearer key, protocol auth success with implied cookie jwt", 3896 nuc: nucBearerFunc(), 3897 opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 3898 co := cliOpts 3899 co.extraHeaders = map[string][]string{} 3900 co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)} 3901 return co, testClaimsOptions{connectRequest: struct{}{}} 3902 }, 3903 expectAnswer: "+OK", 3904 }, 3905 { 3906 name: "cookie auth, non-bearer key, protocol auth success via override jwt in CONNECT opts", 3907 nuc: nucSigFunc(), 3908 opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 3909 co := cliOpts 3910 co.extraHeaders = map[string][]string{} 3911 co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)} 3912 return co, testClaimsOptions{nuc: nucBearerFunc()} 3913 }, 3914 expectAnswer: "+OK", 3915 }, 3916 } { 3917 t.Run(test.name, func(t *testing.T) { 3918 cliOpt, claimOpt := test.opts(t, test.nuc) 3919 claimOpt.expectAnswer = test.expectAnswer 3920 _, c, _, _ := testWSWithClaims(t, s, cliOpt, claimOpt) 3921 c.Close() 3922 }) 3923 } 3924 s.Shutdown() 3925 } 3926 3927 func TestWSReloadTLSConfig(t *testing.T) { 3928 template := ` 3929 listen: "127.0.0.1:-1" 3930 websocket { 3931 listen: "127.0.0.1:-1" 3932 tls { 3933 cert_file: '%s' 3934 key_file: '%s' 3935 ca_file: '../test/configs/certs/ca.pem' 3936 } 3937 } 3938 ` 3939 conf := createConfFile(t, []byte(fmt.Sprintf(template, 3940 "../test/configs/certs/server-noip.pem", 3941 "../test/configs/certs/server-key-noip.pem"))) 3942 3943 s, o := RunServerWithConfig(conf) 3944 defer s.Shutdown() 3945 3946 addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port) 3947 wsc, err := net.Dial("tcp", addr) 3948 if err != nil { 3949 t.Fatalf("Error creating ws connection: %v", err) 3950 } 3951 defer wsc.Close() 3952 3953 tc := &TLSConfigOpts{CaFile: "../test/configs/certs/ca.pem"} 3954 tlsConfig, err := GenTLSConfig(tc) 3955 if err != nil { 3956 t.Fatalf("Error generating TLS config: %v", err) 3957 } 3958 tlsConfig.ServerName = "127.0.0.1" 3959 tlsConfig.RootCAs = tlsConfig.ClientCAs 3960 tlsConfig.ClientCAs = nil 3961 wsc = tls.Client(wsc, tlsConfig.Clone()) 3962 if err := wsc.(*tls.Conn).Handshake(); err == nil || !strings.Contains(err.Error(), "SAN") { 3963 t.Fatalf("Unexpected error: %v", err) 3964 } 3965 wsc.Close() 3966 3967 reloadUpdateConfig(t, s, conf, fmt.Sprintf(template, 3968 "../test/configs/certs/server-cert.pem", 3969 "../test/configs/certs/server-key.pem")) 3970 3971 wsc, err = net.Dial("tcp", addr) 3972 if err != nil { 3973 t.Fatalf("Error creating ws connection: %v", err) 3974 } 3975 defer wsc.Close() 3976 3977 wsc = tls.Client(wsc, tlsConfig.Clone()) 3978 if err := wsc.(*tls.Conn).Handshake(); err != nil { 3979 t.Fatalf("Error on TLS handshake: %v", err) 3980 } 3981 } 3982 3983 type captureClientConnectedLogger struct { 3984 DummyLogger 3985 ch chan string 3986 } 3987 3988 func (l *captureClientConnectedLogger) Debugf(format string, v ...interface{}) { 3989 msg := fmt.Sprintf(format, v...) 3990 if !strings.Contains(msg, "Client connection created") { 3991 return 3992 } 3993 select { 3994 case l.ch <- msg: 3995 default: 3996 } 3997 } 3998 3999 func TestWSXForwardedFor(t *testing.T) { 4000 o := testWSOptions() 4001 s := RunServer(o) 4002 defer s.Shutdown() 4003 4004 l := &captureClientConnectedLogger{ch: make(chan string, 1)} 4005 s.SetLogger(l, true, false) 4006 4007 for _, test := range []struct { 4008 name string 4009 headers func() map[string][]string 4010 useHdrValue bool 4011 expectedValue string 4012 }{ 4013 {"nil map", func() map[string][]string { 4014 return nil 4015 }, false, _EMPTY_}, 4016 {"empty map", func() map[string][]string { 4017 return make(map[string][]string) 4018 }, false, _EMPTY_}, 4019 {"header present empty value", func() map[string][]string { 4020 m := make(map[string][]string) 4021 m[wsXForwardedForHeader] = []string{} 4022 return m 4023 }, false, _EMPTY_}, 4024 {"header present invalid IP", func() map[string][]string { 4025 m := make(map[string][]string) 4026 m[wsXForwardedForHeader] = []string{"not a valid IP"} 4027 return m 4028 }, false, _EMPTY_}, 4029 {"header present one IP", func() map[string][]string { 4030 m := make(map[string][]string) 4031 m[wsXForwardedForHeader] = []string{"1.2.3.4"} 4032 return m 4033 }, true, "1.2.3.4"}, 4034 {"header present multiple IPs", func() map[string][]string { 4035 m := make(map[string][]string) 4036 m[wsXForwardedForHeader] = []string{"1.2.3.4", "5.6.7.8"} 4037 return m 4038 }, true, "1.2.3.4"}, 4039 {"header present IPv6", func() map[string][]string { 4040 m := make(map[string][]string) 4041 m[wsXForwardedForHeader] = []string{"::1"} 4042 return m 4043 }, true, "[::1]"}, 4044 } { 4045 t.Run(test.name, func(t *testing.T) { 4046 c, r, _ := testNewWSClient(t, testWSClientOptions{ 4047 host: o.Websocket.Host, 4048 port: o.Websocket.Port, 4049 extraHeaders: test.headers(), 4050 }) 4051 defer c.Close() 4052 // Send CONNECT and PING 4053 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n")) 4054 if _, err := c.Write(wsmsg); err != nil { 4055 t.Fatalf("Error sending message: %v", err) 4056 } 4057 // Wait for the PONG 4058 if msg := testWSReadFrame(t, r); !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 4059 t.Fatalf("Expected PONG, got %s", msg) 4060 } 4061 select { 4062 case d := <-l.ch: 4063 ipAndSlash := fmt.Sprintf("%s/", test.expectedValue) 4064 if test.useHdrValue { 4065 if !strings.HasPrefix(d, ipAndSlash) { 4066 t.Fatalf("Expected debug statement to start with: %q, got %q", ipAndSlash, d) 4067 } 4068 } else if strings.HasPrefix(d, ipAndSlash) { 4069 t.Fatalf("Unexpected debug statement: %q", d) 4070 } 4071 case <-time.After(time.Second): 4072 t.Fatal("Did not get connect debug statement") 4073 } 4074 }) 4075 } 4076 } 4077 4078 type partialWriteConn struct { 4079 net.Conn 4080 } 4081 4082 func (c *partialWriteConn) Write(b []byte) (int, error) { 4083 max := len(b) 4084 if max > 0 { 4085 max = rand.Intn(max) 4086 if max == 0 { 4087 max = 1 4088 } 4089 } 4090 n, err := c.Conn.Write(b[:max]) 4091 if err == nil && max != len(b) { 4092 err = io.ErrShortWrite 4093 } 4094 return n, err 4095 } 4096 4097 func TestWSWithPartialWrite(t *testing.T) { 4098 conf := createConfFile(t, []byte(` 4099 listen: "127.0.0.1:-1" 4100 websocket { 4101 listen: "127.0.0.1:-1" 4102 no_tls: true 4103 } 4104 `)) 4105 s, o := RunServerWithConfig(conf) 4106 defer s.Shutdown() 4107 4108 nc1 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o.Websocket.Port)) 4109 defer nc1.Close() 4110 4111 sub := natsSubSync(t, nc1, "foo") 4112 sub.SetPendingLimits(-1, -1) 4113 natsFlush(t, nc1) 4114 4115 nc2 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o.Websocket.Port)) 4116 defer nc2.Close() 4117 4118 // Replace websocket connections with ones that will produce short writes. 4119 s.mu.RLock() 4120 for _, c := range s.clients { 4121 c.mu.Lock() 4122 c.nc = &partialWriteConn{Conn: c.nc} 4123 c.mu.Unlock() 4124 } 4125 s.mu.RUnlock() 4126 4127 var msgs [][]byte 4128 for i := 0; i < 100; i++ { 4129 msg := make([]byte, rand.Intn(10000)+10) 4130 for j := 0; j < len(msg); j++ { 4131 msg[j] = byte('A' + j%26) 4132 } 4133 msgs = append(msgs, msg) 4134 natsPub(t, nc2, "foo", msg) 4135 } 4136 for i := 0; i < 100; i++ { 4137 rmsg := natsNexMsg(t, sub, time.Second) 4138 if !bytes.Equal(msgs[i], rmsg.Data) { 4139 t.Fatalf("Expected message %q, got %q", msgs[i], rmsg.Data) 4140 } 4141 } 4142 } 4143 4144 func testWSNoCorruptionWithFrameSizeLimit(t *testing.T, total int) { 4145 tmpl := ` 4146 listen: "127.0.0.1:-1" 4147 cluster { 4148 name: "local" 4149 port: -1 4150 %s 4151 } 4152 websocket { 4153 listen: "127.0.0.1:-1" 4154 no_tls: true 4155 } 4156 ` 4157 conf1 := createConfFile(t, []byte(fmt.Sprintf(tmpl, _EMPTY_))) 4158 s1, o1 := RunServerWithConfig(conf1) 4159 defer s1.Shutdown() 4160 4161 routes := fmt.Sprintf("routes: [\"nats://127.0.0.1:%d\"]", o1.Cluster.Port) 4162 conf2 := createConfFile(t, []byte(fmt.Sprintf(tmpl, routes))) 4163 s2, o2 := RunServerWithConfig(conf2) 4164 defer s2.Shutdown() 4165 4166 conf3 := createConfFile(t, []byte(fmt.Sprintf(tmpl, routes))) 4167 s3, o3 := RunServerWithConfig(conf3) 4168 defer s3.Shutdown() 4169 4170 checkClusterFormed(t, s1, s2, s3) 4171 4172 nc3 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o3.Websocket.Port)) 4173 defer nc3.Close() 4174 4175 nc2 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o2.Websocket.Port)) 4176 defer nc2.Close() 4177 4178 payload := make([]byte, 100000) 4179 for i := 0; i < len(payload); i++ { 4180 payload[i] = 'A' + byte(i%26) 4181 } 4182 errCh := make(chan error, 1) 4183 doneCh := make(chan struct{}, 1) 4184 count := int32(0) 4185 4186 createSub := func(nc *nats.Conn) { 4187 sub := natsSub(t, nc, "foo", func(m *nats.Msg) { 4188 if !bytes.Equal(m.Data, payload) { 4189 stop := len(m.Data) 4190 if l := len(payload); l < stop { 4191 stop = l 4192 } 4193 start := 0 4194 for i := 0; i < stop; i++ { 4195 if m.Data[i] != payload[i] { 4196 start = i 4197 break 4198 } 4199 } 4200 if stop-start > 20 { 4201 stop = start + 20 4202 } 4203 select { 4204 case errCh <- fmt.Errorf("Invalid message: [%d bytes same]%s[...]", start, m.Data[start:stop]): 4205 default: 4206 } 4207 return 4208 } 4209 if n := atomic.AddInt32(&count, 1); int(n) == 2*total { 4210 doneCh <- struct{}{} 4211 } 4212 }) 4213 sub.SetPendingLimits(-1, -1) 4214 } 4215 createSub(nc2) 4216 createSub(nc3) 4217 4218 checkSubInterest(t, s1, globalAccountName, "foo", time.Second) 4219 4220 nc1 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o1.Websocket.Port)) 4221 defer nc1.Close() 4222 natsFlush(t, nc1) 4223 4224 // Change websocket connections to force a max frame size. 4225 for _, s := range []*Server{s1, s2, s3} { 4226 s.mu.RLock() 4227 for _, c := range s.clients { 4228 c.mu.Lock() 4229 if c.ws != nil { 4230 c.ws.browser = true 4231 } 4232 c.mu.Unlock() 4233 } 4234 s.mu.RUnlock() 4235 } 4236 4237 for i := 0; i < total; i++ { 4238 natsPub(t, nc1, "foo", payload) 4239 if i%100 == 0 { 4240 select { 4241 case err := <-errCh: 4242 t.Fatalf("Error: %v", err) 4243 default: 4244 } 4245 } 4246 } 4247 select { 4248 case err := <-errCh: 4249 t.Fatalf("Error: %v", err) 4250 case <-doneCh: 4251 return 4252 case <-time.After(10 * time.Second): 4253 t.Fatalf("Test timed out") 4254 } 4255 } 4256 4257 func TestWSNoCorruptionWithFrameSizeLimit(t *testing.T) { 4258 testWSNoCorruptionWithFrameSizeLimit(t, 1000) 4259 } 4260 4261 // ================================================================== 4262 // = Benchmark tests 4263 // ================================================================== 4264 4265 const testWSBenchSubject = "a" 4266 4267 var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()") 4268 4269 func sizedString(sz int) string { 4270 b := make([]byte, sz) 4271 for i := range b { 4272 b[i] = ch[rand.Intn(len(ch))] 4273 } 4274 return string(b) 4275 } 4276 4277 func sizedStringForCompression(sz int) string { 4278 b := make([]byte, sz) 4279 c := byte(0) 4280 s := 0 4281 for i := range b { 4282 if s%20 == 0 { 4283 c = ch[rand.Intn(len(ch))] 4284 } 4285 b[i] = c 4286 } 4287 return string(b) 4288 } 4289 4290 func testWSFlushConn(b *testing.B, compress bool, c net.Conn, br *bufio.Reader) { 4291 buf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte(pingProto)) 4292 c.Write(buf) 4293 c.SetReadDeadline(time.Now().Add(5 * time.Second)) 4294 res := testWSReadFrame(b, br) 4295 c.SetReadDeadline(time.Time{}) 4296 if !bytes.HasPrefix(res, []byte(pongProto)) { 4297 b.Fatalf("Failed read of PONG: %s\n", res) 4298 } 4299 } 4300 4301 func wsBenchPub(b *testing.B, numPubs int, compress bool, payload string) { 4302 b.StopTimer() 4303 opts := testWSOptions() 4304 opts.Websocket.Compression = compress 4305 s := RunServer(opts) 4306 defer s.Shutdown() 4307 4308 extra := 0 4309 pubProto := []byte(fmt.Sprintf("PUB %s %d\r\n%s\r\n", testWSBenchSubject, len(payload), payload)) 4310 singleOpBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, pubProto) 4311 4312 // Simulate client that would buffer messages before framing/sending. 4313 // Figure out how many we can fit in one frame based on b.N and length of pubProto 4314 const bufSize = 32768 4315 tmpa := [bufSize]byte{} 4316 tmp := tmpa[:0] 4317 pb := 0 4318 for i := 0; i < b.N; i++ { 4319 tmp = append(tmp, pubProto...) 4320 pb++ 4321 if len(tmp) >= bufSize { 4322 break 4323 } 4324 } 4325 sendBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, tmp) 4326 n := b.N / pb 4327 extra = b.N - (n * pb) 4328 4329 wg := sync.WaitGroup{} 4330 wg.Add(numPubs) 4331 4332 type pub struct { 4333 c net.Conn 4334 br *bufio.Reader 4335 bw *bufio.Writer 4336 } 4337 var pubs []pub 4338 for i := 0; i < numPubs; i++ { 4339 wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port) 4340 defer wsc.Close() 4341 bw := bufio.NewWriterSize(wsc, bufSize) 4342 pubs = append(pubs, pub{wsc, br, bw}) 4343 } 4344 4345 // Average the amount of bytes sent by iteration 4346 avg := len(sendBuf) / pb 4347 if extra > 0 { 4348 avg += len(singleOpBuf) 4349 avg /= 2 4350 } 4351 b.SetBytes(int64(numPubs * avg)) 4352 b.StartTimer() 4353 4354 for i := 0; i < numPubs; i++ { 4355 p := pubs[i] 4356 go func(p pub) { 4357 defer wg.Done() 4358 for i := 0; i < n; i++ { 4359 p.bw.Write(sendBuf) 4360 } 4361 for i := 0; i < extra; i++ { 4362 p.bw.Write(singleOpBuf) 4363 } 4364 p.bw.Flush() 4365 testWSFlushConn(b, compress, p.c, p.br) 4366 }(p) 4367 } 4368 wg.Wait() 4369 b.StopTimer() 4370 } 4371 4372 func Benchmark_WS_Pubx1_CN_____0b(b *testing.B) { 4373 wsBenchPub(b, 1, false, "") 4374 } 4375 4376 func Benchmark_WS_Pubx1_CY_____0b(b *testing.B) { 4377 wsBenchPub(b, 1, true, "") 4378 } 4379 4380 func Benchmark_WS_Pubx1_CN___128b(b *testing.B) { 4381 s := sizedString(128) 4382 wsBenchPub(b, 1, false, s) 4383 } 4384 4385 func Benchmark_WS_Pubx1_CY___128b(b *testing.B) { 4386 s := sizedStringForCompression(128) 4387 wsBenchPub(b, 1, true, s) 4388 } 4389 4390 func Benchmark_WS_Pubx1_CN__1024b(b *testing.B) { 4391 s := sizedString(1024) 4392 wsBenchPub(b, 1, false, s) 4393 } 4394 4395 func Benchmark_WS_Pubx1_CY__1024b(b *testing.B) { 4396 s := sizedStringForCompression(1024) 4397 wsBenchPub(b, 1, true, s) 4398 } 4399 4400 func Benchmark_WS_Pubx1_CN__4096b(b *testing.B) { 4401 s := sizedString(4 * 1024) 4402 wsBenchPub(b, 1, false, s) 4403 } 4404 4405 func Benchmark_WS_Pubx1_CY__4096b(b *testing.B) { 4406 s := sizedStringForCompression(4 * 1024) 4407 wsBenchPub(b, 1, true, s) 4408 } 4409 4410 func Benchmark_WS_Pubx1_CN__8192b(b *testing.B) { 4411 s := sizedString(8 * 1024) 4412 wsBenchPub(b, 1, false, s) 4413 } 4414 4415 func Benchmark_WS_Pubx1_CY__8192b(b *testing.B) { 4416 s := sizedStringForCompression(8 * 1024) 4417 wsBenchPub(b, 1, true, s) 4418 } 4419 4420 func Benchmark_WS_Pubx1_CN_32768b(b *testing.B) { 4421 s := sizedString(32 * 1024) 4422 wsBenchPub(b, 1, false, s) 4423 } 4424 4425 func Benchmark_WS_Pubx1_CY_32768b(b *testing.B) { 4426 s := sizedStringForCompression(32 * 1024) 4427 wsBenchPub(b, 1, true, s) 4428 } 4429 4430 func Benchmark_WS_Pubx5_CN_____0b(b *testing.B) { 4431 wsBenchPub(b, 5, false, "") 4432 } 4433 4434 func Benchmark_WS_Pubx5_CY_____0b(b *testing.B) { 4435 wsBenchPub(b, 5, true, "") 4436 } 4437 4438 func Benchmark_WS_Pubx5_CN___128b(b *testing.B) { 4439 s := sizedString(128) 4440 wsBenchPub(b, 5, false, s) 4441 } 4442 4443 func Benchmark_WS_Pubx5_CY___128b(b *testing.B) { 4444 s := sizedStringForCompression(128) 4445 wsBenchPub(b, 5, true, s) 4446 } 4447 4448 func Benchmark_WS_Pubx5_CN__1024b(b *testing.B) { 4449 s := sizedString(1024) 4450 wsBenchPub(b, 5, false, s) 4451 } 4452 4453 func Benchmark_WS_Pubx5_CY__1024b(b *testing.B) { 4454 s := sizedStringForCompression(1024) 4455 wsBenchPub(b, 5, true, s) 4456 } 4457 4458 func Benchmark_WS_Pubx5_CN__4096b(b *testing.B) { 4459 s := sizedString(4 * 1024) 4460 wsBenchPub(b, 5, false, s) 4461 } 4462 4463 func Benchmark_WS_Pubx5_CY__4096b(b *testing.B) { 4464 s := sizedStringForCompression(4 * 1024) 4465 wsBenchPub(b, 5, true, s) 4466 } 4467 4468 func Benchmark_WS_Pubx5_CN__8192b(b *testing.B) { 4469 s := sizedString(8 * 1024) 4470 wsBenchPub(b, 5, false, s) 4471 } 4472 4473 func Benchmark_WS_Pubx5_CY__8192b(b *testing.B) { 4474 s := sizedStringForCompression(8 * 1024) 4475 wsBenchPub(b, 5, true, s) 4476 } 4477 4478 func Benchmark_WS_Pubx5_CN_32768b(b *testing.B) { 4479 s := sizedString(32 * 1024) 4480 wsBenchPub(b, 5, false, s) 4481 } 4482 4483 func Benchmark_WS_Pubx5_CY_32768b(b *testing.B) { 4484 s := sizedStringForCompression(32 * 1024) 4485 wsBenchPub(b, 5, true, s) 4486 } 4487 4488 func wsBenchSub(b *testing.B, numSubs int, compress bool, payload string) { 4489 b.StopTimer() 4490 opts := testWSOptions() 4491 opts.Websocket.Compression = compress 4492 s := RunServer(opts) 4493 defer s.Shutdown() 4494 4495 var subs []*bufio.Reader 4496 for i := 0; i < numSubs; i++ { 4497 wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port) 4498 defer wsc.Close() 4499 subProto := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, 4500 []byte(fmt.Sprintf("SUB %s 1\r\nPING\r\n", testWSBenchSubject))) 4501 wsc.Write(subProto) 4502 // Waiting for PONG 4503 testWSReadFrame(b, br) 4504 subs = append(subs, br) 4505 } 4506 4507 wg := sync.WaitGroup{} 4508 wg.Add(numSubs) 4509 4510 // Use regular NATS client to publish messages 4511 nc := natsConnect(b, s.ClientURL()) 4512 defer nc.Close() 4513 4514 b.StartTimer() 4515 4516 for i := 0; i < numSubs; i++ { 4517 br := subs[i] 4518 go func(br *bufio.Reader) { 4519 defer wg.Done() 4520 for count := 0; count < b.N; { 4521 msgs := testWSReadFrame(b, br) 4522 count += bytes.Count(msgs, []byte("MSG ")) 4523 } 4524 }(br) 4525 } 4526 for i := 0; i < b.N; i++ { 4527 natsPub(b, nc, testWSBenchSubject, []byte(payload)) 4528 } 4529 wg.Wait() 4530 b.StopTimer() 4531 } 4532 4533 func Benchmark_WS_Subx1_CN_____0b(b *testing.B) { 4534 wsBenchSub(b, 1, false, "") 4535 } 4536 4537 func Benchmark_WS_Subx1_CY_____0b(b *testing.B) { 4538 wsBenchSub(b, 1, true, "") 4539 } 4540 4541 func Benchmark_WS_Subx1_CN___128b(b *testing.B) { 4542 s := sizedString(128) 4543 wsBenchSub(b, 1, false, s) 4544 } 4545 4546 func Benchmark_WS_Subx1_CY___128b(b *testing.B) { 4547 s := sizedStringForCompression(128) 4548 wsBenchSub(b, 1, true, s) 4549 } 4550 4551 func Benchmark_WS_Subx1_CN__1024b(b *testing.B) { 4552 s := sizedString(1024) 4553 wsBenchSub(b, 1, false, s) 4554 } 4555 4556 func Benchmark_WS_Subx1_CY__1024b(b *testing.B) { 4557 s := sizedStringForCompression(1024) 4558 wsBenchSub(b, 1, true, s) 4559 } 4560 4561 func Benchmark_WS_Subx1_CN__4096b(b *testing.B) { 4562 s := sizedString(4096) 4563 wsBenchSub(b, 1, false, s) 4564 } 4565 4566 func Benchmark_WS_Subx1_CY__4096b(b *testing.B) { 4567 s := sizedStringForCompression(4096) 4568 wsBenchSub(b, 1, true, s) 4569 } 4570 4571 func Benchmark_WS_Subx1_CN__8192b(b *testing.B) { 4572 s := sizedString(8192) 4573 wsBenchSub(b, 1, false, s) 4574 } 4575 4576 func Benchmark_WS_Subx1_CY__8192b(b *testing.B) { 4577 s := sizedStringForCompression(8192) 4578 wsBenchSub(b, 1, true, s) 4579 } 4580 4581 func Benchmark_WS_Subx1_CN_32768b(b *testing.B) { 4582 s := sizedString(32768) 4583 wsBenchSub(b, 1, false, s) 4584 } 4585 4586 func Benchmark_WS_Subx1_CY_32768b(b *testing.B) { 4587 s := sizedStringForCompression(32768) 4588 wsBenchSub(b, 1, true, s) 4589 } 4590 4591 func Benchmark_WS_Subx5_CN_____0b(b *testing.B) { 4592 wsBenchSub(b, 5, false, "") 4593 } 4594 4595 func Benchmark_WS_Subx5_CY_____0b(b *testing.B) { 4596 wsBenchSub(b, 5, true, "") 4597 } 4598 4599 func Benchmark_WS_Subx5_CN___128b(b *testing.B) { 4600 s := sizedString(128) 4601 wsBenchSub(b, 5, false, s) 4602 } 4603 4604 func Benchmark_WS_Subx5_CY___128b(b *testing.B) { 4605 s := sizedStringForCompression(128) 4606 wsBenchSub(b, 5, true, s) 4607 } 4608 4609 func Benchmark_WS_Subx5_CN__1024b(b *testing.B) { 4610 s := sizedString(1024) 4611 wsBenchSub(b, 5, false, s) 4612 } 4613 4614 func Benchmark_WS_Subx5_CY__1024b(b *testing.B) { 4615 s := sizedStringForCompression(1024) 4616 wsBenchSub(b, 5, true, s) 4617 } 4618 4619 func Benchmark_WS_Subx5_CN__4096b(b *testing.B) { 4620 s := sizedString(4096) 4621 wsBenchSub(b, 5, false, s) 4622 } 4623 4624 func Benchmark_WS_Subx5_CY__4096b(b *testing.B) { 4625 s := sizedStringForCompression(4096) 4626 wsBenchSub(b, 5, true, s) 4627 } 4628 4629 func Benchmark_WS_Subx5_CN__8192b(b *testing.B) { 4630 s := sizedString(8192) 4631 wsBenchSub(b, 5, false, s) 4632 } 4633 4634 func Benchmark_WS_Subx5_CY__8192b(b *testing.B) { 4635 s := sizedStringForCompression(8192) 4636 wsBenchSub(b, 5, true, s) 4637 } 4638 4639 func Benchmark_WS_Subx5_CN_32768b(b *testing.B) { 4640 s := sizedString(32768) 4641 wsBenchSub(b, 5, false, s) 4642 } 4643 4644 func Benchmark_WS_Subx5_CY_32768b(b *testing.B) { 4645 s := sizedStringForCompression(32768) 4646 wsBenchSub(b, 5, true, s) 4647 }