// testReader is an in-memory reader used by the tests to control exactly
// what a Read call returns: the bytes left in buf, at most max bytes per
// call when max is positive, or err when it is set.
type testReader struct {
	buf []byte // data served to callers
	pos int    // offset in buf of the next byte to serve
	max int    // per-call cap on returned bytes; values <= 0 mean no cap
	err error  // when non-nil, every Read fails with this error
}

// Read copies the next unread bytes of tr.buf into p and advances tr.pos.
// It returns (0, tr.err) when an error is configured and (0, nil) when
// nothing is left to serve.
func (tr *testReader) Read(p []byte) (int, error) {
	if tr.err != nil {
		return 0, tr.err
	}
	remaining := tr.buf[tr.pos:]
	if len(remaining) == 0 {
		return 0, nil
	}
	// The amount served is bounded by the destination size and, when
	// configured, by the per-call cap.
	limit := len(p)
	if tr.max > 0 && tr.max < limit {
		limit = tr.max
	}
	if len(remaining) > limit {
		remaining = remaining[:limit]
	}
	n := copy(p, remaining)
	tr.pos += n
	return n, nil
}
// testWSSimpleMask XORs buf in place with key, cycling through the first
// four key bytes. Applying it twice with the same key restores buf.
func testWSSimpleMask(key, buf []byte) {
	for idx := range buf {
		buf[idx] ^= key[idx%4]
	}
}
append([]byte(nil), orgBuf...) 157 testWSSimpleMask(key, buf) 158 // First ensure that the content is masked. 159 if bytes.Equal(buf, orgBuf) { 160 t.Fatalf("Masking did not do anything: %q", buf) 161 } 162 return buf 163 } 164 165 ri := &wsReadInfo{mask: true} 166 ri.init() 167 copy(ri.mkey[:], key) 168 169 buf := mask() 170 // Unmask in one call 171 ri.unmask(buf) 172 if !bytes.Equal(buf, orgBuf) { 173 t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf) 174 } 175 176 // Unmask in multiple calls 177 buf = mask() 178 ri.mkpos = 0 179 ri.unmask(buf[:3]) 180 ri.unmask(buf[3:11]) 181 ri.unmask(buf[11:]) 182 if !bytes.Equal(buf, orgBuf) { 183 t.Fatalf("Unmask error, expected %q, got %q", orgBuf, buf) 184 } 185 } 186 187 func TestWSCreateCloseMessage(t *testing.T) { 188 for _, test := range []struct { 189 name string 190 status int 191 psize int 192 truncated bool 193 }{ 194 {"fits", wsCloseStatusInternalSrvError, 10, false}, 195 {"truncated", wsCloseStatusProtocolError, wsMaxControlPayloadSize + 10, true}, 196 } { 197 t.Run(test.name, func(t *testing.T) { 198 payload := make([]byte, test.psize) 199 for i := 0; i < len(payload); i++ { 200 payload[i] = byte('A' + (i % 26)) 201 } 202 res := wsCreateCloseMessage(test.status, string(payload)) 203 if status := binary.BigEndian.Uint16(res[:2]); int(status) != test.status { 204 t.Fatalf("Expected status to be %v, got %v", test.status, status) 205 } 206 psize := len(res) - 2 207 if !test.truncated { 208 if int(psize) != test.psize { 209 t.Fatalf("Expected size to be %v, got %v", test.psize, psize) 210 } 211 if !bytes.Equal(res[2:], payload) { 212 t.Fatalf("Unexpected result: %q", res[2:]) 213 } 214 return 215 } 216 // Since the payload of a close message contains a 2 byte status, the 217 // actual max text size will be wsMaxControlPayloadSize-2 218 if int(psize) != wsMaxControlPayloadSize-2 { 219 t.Fatalf("Expected size to be capped to %v, got %v", wsMaxControlPayloadSize-2, psize) 220 } 221 if string(res[len(res)-3:]) 
!= "..." { 222 t.Fatalf("Expected res to have `...` at the end, got %q", res[4:]) 223 } 224 }) 225 } 226 } 227 228 func TestWSCreateFrameHeader(t *testing.T) { 229 for _, test := range []struct { 230 name string 231 frameType wsOpCode 232 compressed bool 233 len int 234 }{ 235 {"uncompressed 10", wsBinaryMessage, false, 10}, 236 {"uncompressed 600", wsTextMessage, false, 600}, 237 {"uncompressed 100000", wsTextMessage, false, 100000}, 238 {"compressed 10", wsBinaryMessage, true, 10}, 239 {"compressed 600", wsBinaryMessage, true, 600}, 240 {"compressed 100000", wsTextMessage, true, 100000}, 241 } { 242 t.Run(test.name, func(t *testing.T) { 243 res, _ := wsCreateFrameHeader(false, test.compressed, test.frameType, test.len) 244 // The server is always sending the message has a single frame, 245 // so the "final" bit should be set. 246 expected := byte(test.frameType) | wsFinalBit 247 if test.compressed { 248 expected |= wsRsv1Bit 249 } 250 if b := res[0]; b != expected { 251 t.Fatalf("Expected first byte to be %v, got %v", expected, b) 252 } 253 switch { 254 case test.len <= 125: 255 if len(res) != 2 { 256 t.Fatalf("Frame len should be 2, got %v", len(res)) 257 } 258 if res[1] != byte(test.len) { 259 t.Fatalf("Expected len to be in second byte and be %v, got %v", test.len, res[1]) 260 } 261 case test.len < 65536: 262 // 1+1+2 263 if len(res) != 4 { 264 t.Fatalf("Frame len should be 4, got %v", len(res)) 265 } 266 if res[1] != 126 { 267 t.Fatalf("Second byte value should be 126, got %v", res[1]) 268 } 269 if rl := binary.BigEndian.Uint16(res[2:]); int(rl) != test.len { 270 t.Fatalf("Expected len to be %v, got %v", test.len, rl) 271 } 272 default: 273 // 1+1+8 274 if len(res) != 10 { 275 t.Fatalf("Frame len should be 10, got %v", len(res)) 276 } 277 if res[1] != 127 { 278 t.Fatalf("Second byte value should be 127, got %v", res[1]) 279 } 280 if rl := binary.BigEndian.Uint64(res[2:]); int(rl) != test.len { 281 t.Fatalf("Expected len to be %v, got %v", test.len, rl) 282 } 
283 } 284 }) 285 } 286 } 287 288 func testWSCreateClientMsg(frameType wsOpCode, frameNum int, final, compressed bool, payload []byte) []byte { 289 if compressed { 290 buf := &bytes.Buffer{} 291 compressor, _ := flate.NewWriter(buf, 1) 292 compressor.Write(payload) 293 compressor.Flush() 294 payload = buf.Bytes() 295 // The last 4 bytes are dropped 296 payload = payload[:len(payload)-4] 297 } 298 frame := make([]byte, 14+len(payload)) 299 if frameNum == 1 { 300 frame[0] = byte(frameType) 301 } 302 if final { 303 frame[0] |= wsFinalBit 304 } 305 if compressed { 306 frame[0] |= wsRsv1Bit 307 } 308 pos := 1 309 lenPayload := len(payload) 310 switch { 311 case lenPayload <= 125: 312 frame[pos] = byte(lenPayload) | wsMaskBit 313 pos++ 314 case lenPayload < 65536: 315 frame[pos] = 126 | wsMaskBit 316 binary.BigEndian.PutUint16(frame[2:], uint16(lenPayload)) 317 pos += 3 318 default: 319 frame[1] = 127 | wsMaskBit 320 binary.BigEndian.PutUint64(frame[2:], uint64(lenPayload)) 321 pos += 9 322 } 323 key := []byte{1, 2, 3, 4} 324 copy(frame[pos:], key) 325 pos += 4 326 copy(frame[pos:], payload) 327 testWSSimpleMask(key, frame[pos:]) 328 pos += lenPayload 329 return frame[:pos] 330 } 331 332 func testWSSetupForRead() (*client, *wsReadInfo, *testReader) { 333 ri := &wsReadInfo{mask: true} 334 ri.init() 335 tr := &testReader{} 336 opts := DefaultOptions() 337 opts.MaxPending = MAX_PENDING_SIZE 338 s := &Server{opts: opts} 339 c := &client{srv: s, ws: &websocket{}} 340 c.initClient() 341 return c, ri, tr 342 } 343 344 func TestWSReadUncompressedFrames(t *testing.T) { 345 c, ri, tr := testWSSetupForRead() 346 // Create 2 WS messages 347 pl1 := []byte("first message") 348 wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl1) 349 pl2 := []byte("second message") 350 wsmsg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, pl2) 351 // Add both in single buffer 352 orgrb := append([]byte(nil), wsmsg1...) 353 orgrb = append(orgrb, wsmsg2...) 
354 355 rb := append([]byte(nil), orgrb...) 356 bufs, err := c.wsRead(ri, tr, rb) 357 if err != nil { 358 t.Fatalf("Unexpected error: %v", err) 359 } 360 if n := len(bufs); n != 2 { 361 t.Fatalf("Expected 2 buffers, got %v", n) 362 } 363 if !bytes.Equal(bufs[0], pl1) { 364 t.Fatalf("Unexpected content for buffer 1: %s", bufs[0]) 365 } 366 if !bytes.Equal(bufs[1], pl2) { 367 t.Fatalf("Unexpected content for buffer 2: %s", bufs[1]) 368 } 369 370 // Now reset and try with the read buffer not containing full ws frame 371 c, ri, tr = testWSSetupForRead() 372 rb = append([]byte(nil), orgrb...) 373 // Frame is 1+1+4+'first message'. So say we pass with rb of 11 bytes, 374 // then we should get "first" 375 bufs, err = c.wsRead(ri, tr, rb[:11]) 376 if err != nil { 377 t.Fatalf("Unexpected error: %v", err) 378 } 379 if n := len(bufs); n != 1 { 380 t.Fatalf("Unexpected buffer returned: %v", n) 381 } 382 if string(bufs[0]) != "first" { 383 t.Fatalf("Unexpected content: %q", bufs[0]) 384 } 385 // Call again with more data.. 
386 bufs, err = c.wsRead(ri, tr, rb[11:32]) 387 if err != nil { 388 t.Fatalf("Unexpected error: %v", err) 389 } 390 if n := len(bufs); n != 2 { 391 t.Fatalf("Unexpected buffer returned: %v", n) 392 } 393 if string(bufs[0]) != " message" { 394 t.Fatalf("Unexpected content: %q", bufs[0]) 395 } 396 if string(bufs[1]) != "second " { 397 t.Fatalf("Unexpected content: %q", bufs[1]) 398 } 399 // Call with the rest 400 bufs, err = c.wsRead(ri, tr, rb[32:]) 401 if err != nil { 402 t.Fatalf("Unexpected error: %v", err) 403 } 404 if n := len(bufs); n != 1 { 405 t.Fatalf("Unexpected buffer returned: %v", n) 406 } 407 if string(bufs[0]) != "message" { 408 t.Fatalf("Unexpected content: %q", bufs[0]) 409 } 410 } 411 412 func TestWSReadCompressedFrames(t *testing.T) { 413 c, ri, tr := testWSSetupForRead() 414 uncompressed := []byte("this is the uncompress data") 415 wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed) 416 rb := append([]byte(nil), wsmsg1...) 417 // Call with some but not all of the payload 418 bufs, err := c.wsRead(ri, tr, rb[:10]) 419 if err != nil { 420 t.Fatalf("Unexpected error: %v", err) 421 } 422 if n := len(bufs); n != 0 { 423 t.Fatalf("Unexpected buffer returned: %v", n) 424 } 425 // Call with the rest, only then should we get the uncompressed data. 426 bufs, err = c.wsRead(ri, tr, rb[10:]) 427 if err != nil { 428 t.Fatalf("Unexpected error: %v", err) 429 } 430 if n := len(bufs); n != 1 { 431 t.Fatalf("Unexpected buffer returned: %v", n) 432 } 433 if !bytes.Equal(bufs[0], uncompressed) { 434 t.Fatalf("Unexpected content: %s", bufs[0]) 435 } 436 // Stress the fact that we use a pool and want to make sure 437 // that if we get a decompressor from the pool, it is properly reset 438 // with the buffer to decompress. 439 // Since we unmask the read buffer, reset it now and fill it 440 // with 10 compressed frames. 441 rb = nil 442 for i := 0; i < 10; i++ { 443 rb = append(rb, wsmsg1...) 
444 } 445 bufs, err = c.wsRead(ri, tr, rb) 446 if err != nil { 447 t.Fatalf("Unexpected error: %v", err) 448 } 449 if n := len(bufs); n != 10 { 450 t.Fatalf("Unexpected buffer returned: %v", n) 451 } 452 453 // Compress a message and send it in several frames. 454 buf := &bytes.Buffer{} 455 compressor, _ := flate.NewWriter(buf, 1) 456 compressor.Write(uncompressed) 457 compressor.Flush() 458 compressed := buf.Bytes() 459 // The last 4 bytes are dropped 460 compressed = compressed[:len(compressed)-4] 461 ncomp := 10 462 frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, compressed[:ncomp]) 463 frag1[0] |= wsRsv1Bit 464 frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, compressed[ncomp:]) 465 rb = append([]byte(nil), frag1...) 466 rb = append(rb, frag2...) 467 bufs, err = c.wsRead(ri, tr, rb) 468 if err != nil { 469 t.Fatalf("Unexpected error: %v", err) 470 } 471 if n := len(bufs); n != 1 { 472 t.Fatalf("Unexpected buffer returned: %v", n) 473 } 474 if !bytes.Equal(bufs[0], uncompressed) { 475 t.Fatalf("Unexpected content: %s", bufs[0]) 476 } 477 } 478 479 func TestWSReadCompressedFrameCorrupted(t *testing.T) { 480 c, ri, tr := testWSSetupForRead() 481 uncompressed := []byte("this is the uncompress data") 482 wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, uncompressed) 483 copy(wsmsg1[10:], []byte{1, 2, 3, 4}) 484 rb := append([]byte(nil), wsmsg1...) 
485 bufs, err := c.wsRead(ri, tr, rb) 486 if err == nil || !strings.Contains(err.Error(), "corrupt") { 487 t.Fatalf("Expected error about corrupted data, got %v", err) 488 } 489 if n := len(bufs); n != 0 { 490 t.Fatalf("Expected no buffer, got %v", n) 491 } 492 } 493 494 func TestWSReadVariousFrameSizes(t *testing.T) { 495 for _, test := range []struct { 496 name string 497 size int 498 }{ 499 {"tiny", 100}, 500 {"medium", 1000}, 501 {"large", 70000}, 502 } { 503 t.Run(test.name, func(t *testing.T) { 504 c, ri, tr := testWSSetupForRead() 505 uncompressed := make([]byte, test.size) 506 for i := 0; i < len(uncompressed); i++ { 507 uncompressed[i] = 'A' + byte(i%26) 508 } 509 wsmsg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, uncompressed) 510 rb := append([]byte(nil), wsmsg1...) 511 bufs, err := c.wsRead(ri, tr, rb) 512 if err != nil { 513 t.Fatalf("Unexpected error: %v", err) 514 } 515 if n := len(bufs); n != 1 { 516 t.Fatalf("Unexpected buffer returned: %v", n) 517 } 518 if !bytes.Equal(bufs[0], uncompressed) { 519 t.Fatalf("Unexpected content: %s", bufs[0]) 520 } 521 }) 522 } 523 } 524 525 func TestWSReadFragmentedFrames(t *testing.T) { 526 c, ri, tr := testWSSetupForRead() 527 payloads := []string{"first", "second", "third"} 528 var rb []byte 529 for i := 0; i < len(payloads); i++ { 530 final := i == len(payloads)-1 531 frag := testWSCreateClientMsg(wsBinaryMessage, i+1, final, false, []byte(payloads[i])) 532 rb = append(rb, frag...) 
533 } 534 bufs, err := c.wsRead(ri, tr, rb) 535 if err != nil { 536 t.Fatalf("Unexpected error: %v", err) 537 } 538 if n := len(bufs); n != 3 { 539 t.Fatalf("Unexpected buffer returned: %v", n) 540 } 541 for i, expected := range payloads { 542 if string(bufs[i]) != expected { 543 t.Fatalf("Unexpected content for buf=%v: %s", i, bufs[i]) 544 } 545 } 546 } 547 548 func TestWSReadPartialFrameHeaderAtEndOfReadBuffer(t *testing.T) { 549 c, ri, tr := testWSSetupForRead() 550 msg1 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg1")) 551 msg2 := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg2")) 552 rb := append([]byte(nil), msg1...) 553 rb = append(rb, msg2...) 554 // We will pass the first frame + the first byte of the next frame. 555 rbl := rb[:len(msg1)+1] 556 // Make the io reader return the rest of the frame 557 tr.buf = rb[len(msg1)+1:] 558 bufs, err := c.wsRead(ri, tr, rbl) 559 if err != nil { 560 t.Fatalf("Unexpected error: %v", err) 561 } 562 if n := len(bufs); n != 1 { 563 t.Fatalf("Unexpected buffer returned: %v", n) 564 } 565 // We should not have asked to the io reader more than what is needed for reading 566 // the frame header. Since we had already the first byte in the read buffer, 567 // tr.pos should be 1(size)+4(key)=5 568 if tr.pos != 5 { 569 t.Fatalf("Expected reader pos to be 5, got %v", tr.pos) 570 } 571 } 572 573 func TestWSReadPingFrame(t *testing.T) { 574 for _, test := range []struct { 575 name string 576 payload []byte 577 }{ 578 {"without payload", nil}, 579 {"with payload", []byte("optional payload")}, 580 } { 581 t.Run(test.name, func(t *testing.T) { 582 c, ri, tr := testWSSetupForRead() 583 ping := testWSCreateClientMsg(wsPingMessage, 1, true, false, test.payload) 584 rb := append([]byte(nil), ping...) 
585 bufs, err := c.wsRead(ri, tr, rb) 586 if err != nil { 587 t.Fatalf("Unexpected error: %v", err) 588 } 589 if n := len(bufs); n != 0 { 590 t.Fatalf("Unexpected buffer returned: %v", n) 591 } 592 // A PONG should have been queued with the payload of the ping 593 c.mu.Lock() 594 nb, _ := c.collapsePtoNB() 595 c.mu.Unlock() 596 if n := len(nb); n == 0 { 597 t.Fatalf("Expected buffers, got %v", n) 598 } 599 if expected := 2 + len(test.payload); expected != len(nb[0]) { 600 t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) 601 } 602 b := nb[0][0] 603 if b&wsFinalBit == 0 { 604 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 605 } 606 if b&byte(wsPongMessage) == 0 { 607 t.Fatalf("Should have been a PONG, it wasn't: %v", b) 608 } 609 if len(test.payload) > 0 { 610 if !bytes.Equal(nb[0][2:], test.payload) { 611 t.Fatalf("Unexpected content: %s", nb[0][2:]) 612 } 613 } 614 }) 615 } 616 } 617 618 func TestWSReadPongFrame(t *testing.T) { 619 for _, test := range []struct { 620 name string 621 payload []byte 622 }{ 623 {"without payload", nil}, 624 {"with payload", []byte("optional payload")}, 625 } { 626 t.Run(test.name, func(t *testing.T) { 627 c, ri, tr := testWSSetupForRead() 628 pong := testWSCreateClientMsg(wsPongMessage, 1, true, false, test.payload) 629 rb := append([]byte(nil), pong...) 630 bufs, err := c.wsRead(ri, tr, rb) 631 if err != nil { 632 t.Fatalf("Unexpected error: %v", err) 633 } 634 if n := len(bufs); n != 0 { 635 t.Fatalf("Unexpected buffer returned: %v", n) 636 } 637 // Nothing should be sent... 
638 c.mu.Lock() 639 nb, _ := c.collapsePtoNB() 640 c.mu.Unlock() 641 if n := len(nb); n != 0 { 642 t.Fatalf("Expected no buffer, got %v", n) 643 } 644 }) 645 } 646 } 647 648 func TestWSReadCloseFrame(t *testing.T) { 649 for _, test := range []struct { 650 name string 651 payload []byte 652 }{ 653 {"without payload", nil}, 654 {"with payload", []byte("optional payload")}, 655 } { 656 t.Run(test.name, func(t *testing.T) { 657 c, ri, tr := testWSSetupForRead() 658 // a close message has a status in 2 bytes + optional payload 659 payload := make([]byte, 2+len(test.payload)) 660 binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure) 661 if len(test.payload) > 0 { 662 copy(payload[2:], test.payload) 663 } 664 close := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload) 665 // Have a normal frame prior to close to make sure that wsRead returns 666 // the normal frame along with io.EOF to indicate that wsCloseMessage was received. 667 msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg")) 668 rb := append([]byte(nil), msg...) 669 rb = append(rb, close...) 670 bufs, err := c.wsRead(ri, tr, rb) 671 // It is expected that wsRead returns io.EOF on processing a close. 672 if err != io.EOF { 673 t.Fatalf("Unexpected error: %v", err) 674 } 675 if n := len(bufs); n != 1 { 676 t.Fatalf("Unexpected buffer returned: %v", n) 677 } 678 if string(bufs[0]) != "msg" { 679 t.Fatalf("Unexpected content: %s", bufs[0]) 680 } 681 // A CLOSE should have been queued with the payload of the original close message. 
682 c.mu.Lock() 683 nb, _ := c.collapsePtoNB() 684 c.mu.Unlock() 685 if n := len(nb); n == 0 { 686 t.Fatalf("Expected buffers, got %v", n) 687 } 688 if expected := 2 + 2 + len(test.payload); expected != len(nb[0]) { 689 t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) 690 } 691 b := nb[0][0] 692 if b&wsFinalBit == 0 { 693 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 694 } 695 if b&byte(wsCloseMessage) == 0 { 696 t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) 697 } 698 if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure { 699 t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status) 700 } 701 if len(test.payload) > 0 { 702 if !bytes.Equal(nb[0][4:], test.payload) { 703 t.Fatalf("Unexpected content: %s", nb[0][4:]) 704 } 705 } 706 }) 707 } 708 } 709 710 func TestWSReadControlFrameBetweebFragmentedFrames(t *testing.T) { 711 c, ri, tr := testWSSetupForRead() 712 frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("first")) 713 frag2 := testWSCreateClientMsg(wsBinaryMessage, 2, true, false, []byte("second")) 714 ctrl := testWSCreateClientMsg(wsPongMessage, 1, true, false, nil) 715 rb := append([]byte(nil), frag1...) 716 rb = append(rb, ctrl...) 717 rb = append(rb, frag2...) 
718 bufs, err := c.wsRead(ri, tr, rb) 719 if err != nil { 720 t.Fatalf("Unexpected error: %v", err) 721 } 722 if n := len(bufs); n != 2 { 723 t.Fatalf("Unexpected buffer returned: %v", n) 724 } 725 if string(bufs[0]) != "first" { 726 t.Fatalf("Unexpected content: %s", bufs[0]) 727 } 728 if string(bufs[1]) != "second" { 729 t.Fatalf("Unexpected content: %s", bufs[1]) 730 } 731 } 732 733 func TestWSCloseFrameWithPartialOrInvalid(t *testing.T) { 734 c, ri, tr := testWSSetupForRead() 735 // a close message has a status in 2 bytes + optional payload 736 payloadTxt := []byte("hello") 737 payload := make([]byte, 2+len(payloadTxt)) 738 binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure) 739 copy(payload[2:], payloadTxt) 740 closeMsg := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload) 741 742 // We will pass to wsRead a buffer of small capacity that contains 743 // only 1 byte. 744 closeFirtByte := []byte{closeMsg[0]} 745 // Make the io reader return the rest of the frame 746 tr.buf = closeMsg[1:] 747 bufs, err := c.wsRead(ri, tr, closeFirtByte[:]) 748 // It is expected that wsRead returns io.EOF on processing a close. 749 if err != io.EOF { 750 t.Fatalf("Unexpected error: %v", err) 751 } 752 if n := len(bufs); n != 0 { 753 t.Fatalf("Unexpected buffer returned: %v", n) 754 } 755 // A CLOSE should have been queued with the payload of the original close message. 
756 c.mu.Lock() 757 nb, _ := c.collapsePtoNB() 758 c.mu.Unlock() 759 if n := len(nb); n == 0 { 760 t.Fatalf("Expected buffers, got %v", n) 761 } 762 if expected := 2 + 2 + len(payloadTxt); expected != len(nb[0]) { 763 t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) 764 } 765 b := nb[0][0] 766 if b&wsFinalBit == 0 { 767 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 768 } 769 if b&byte(wsCloseMessage) == 0 { 770 t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) 771 } 772 if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNormalClosure { 773 t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNormalClosure, status) 774 } 775 if !bytes.Equal(nb[0][4:], payloadTxt) { 776 t.Fatalf("Unexpected content: %s", nb[0][4:]) 777 } 778 779 // Now test close with invalid status size (1 instead of 2 bytes) 780 c, ri, tr = testWSSetupForRead() 781 payload[0] = 100 782 binary.BigEndian.PutUint16(payload, wsCloseStatusNormalClosure) 783 closeMsg = testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload[:1]) 784 785 // We will pass to wsRead a buffer of small capacity that contains 786 // only 1 byte. 787 closeFirtByte = []byte{closeMsg[0]} 788 // Make the io reader return the rest of the frame 789 tr.buf = closeMsg[1:] 790 bufs, err = c.wsRead(ri, tr, closeFirtByte[:]) 791 // It is expected that wsRead returns io.EOF on processing a close. 792 if err != io.EOF { 793 t.Fatalf("Unexpected error: %v", err) 794 } 795 if n := len(bufs); n != 0 { 796 t.Fatalf("Unexpected buffer returned: %v", n) 797 } 798 // A CLOSE should have been queued with the payload of the original close message. 
799 c.mu.Lock() 800 nb, _ = c.collapsePtoNB() 801 c.mu.Unlock() 802 if n := len(nb); n == 0 { 803 t.Fatalf("Expected buffers, got %v", n) 804 } 805 if expected := 2 + 2; expected != len(nb[0]) { 806 t.Fatalf("Expected buffer to be %v bytes long, got %v", expected, len(nb[0])) 807 } 808 b = nb[0][0] 809 if b&wsFinalBit == 0 { 810 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 811 } 812 if b&byte(wsCloseMessage) == 0 { 813 t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) 814 } 815 // Since satus was not valid, we should get wsCloseStatusNoStatusReceived 816 if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusNoStatusReceived { 817 t.Fatalf("Expected status to be %v, got %v", wsCloseStatusNoStatusReceived, status) 818 } 819 if len(nb[0][:]) != 4 { 820 t.Fatalf("Unexpected content: %s", nb[0][2:]) 821 } 822 } 823 824 func TestWSReadGetErrors(t *testing.T) { 825 tr := &testReader{err: fmt.Errorf("on purpose")} 826 for _, test := range []struct { 827 lenPayload int 828 rbextra int 829 }{ 830 {10, 1}, 831 {10, 3}, 832 {200, 1}, 833 {200, 2}, 834 {200, 5}, 835 {70000, 1}, 836 {70000, 5}, 837 {70000, 13}, 838 } { 839 t.Run("", func(t *testing.T) { 840 c, ri, _ := testWSSetupForRead() 841 msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("msg")) 842 frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, make([]byte, test.lenPayload)) 843 rb := append([]byte(nil), msg...) 844 rb = append(rb, frame...) 
845 bufs, err := c.wsRead(ri, tr, rb[:len(msg)+test.rbextra]) 846 if err == nil || err.Error() != "on purpose" { 847 t.Fatalf("Expected 'on purpose' error, got %v", err) 848 } 849 if n := len(bufs); n != 1 { 850 t.Fatalf("Unexpected buffer returned: %v", n) 851 } 852 if string(bufs[0]) != "msg" { 853 t.Fatalf("Unexpected content: %s", bufs[0]) 854 } 855 }) 856 } 857 } 858 859 func TestWSHandleControlFrameErrors(t *testing.T) { 860 c, ri, tr := testWSSetupForRead() 861 tr.err = fmt.Errorf("on purpose") 862 863 // a close message has a status in 2 bytes + optional payload 864 text := []byte("this is a close message") 865 payload := make([]byte, 2+len(text)) 866 binary.BigEndian.PutUint16(payload[:2], wsCloseStatusNormalClosure) 867 copy(payload[2:], text) 868 ctrl := testWSCreateClientMsg(wsCloseMessage, 1, true, false, payload) 869 870 bufs, err := c.wsRead(ri, tr, ctrl[:len(ctrl)-4]) 871 if err == nil || err.Error() != "on purpose" { 872 t.Fatalf("Expected 'on purpose' error, got %v", err) 873 } 874 if n := len(bufs); n != 0 { 875 t.Fatalf("Unexpected buffer returned: %v", n) 876 } 877 878 // Alter the content of close message. It is supposed to be valid utf-8. 879 c, ri, tr = testWSSetupForRead() 880 cp := append([]byte(nil), payload...) 
881 cp[10] = 0xF1 882 ctrl = testWSCreateClientMsg(wsCloseMessage, 1, true, false, cp) 883 bufs, err = c.wsRead(ri, tr, ctrl) 884 // We should still receive an EOF but the message enqueued to the client 885 // should contain wsCloseStatusInvalidPayloadData and the error about invalid utf8 886 if err != io.EOF { 887 t.Fatalf("Unexpected error: %v", err) 888 } 889 if n := len(bufs); n != 0 { 890 t.Fatalf("Unexpected buffer returned: %v", n) 891 } 892 c.mu.Lock() 893 nb, _ := c.collapsePtoNB() 894 c.mu.Unlock() 895 if n := len(nb); n == 0 { 896 t.Fatalf("Expected buffers, got %v", n) 897 } 898 b := nb[0][0] 899 if b&wsFinalBit == 0 { 900 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 901 } 902 if b&byte(wsCloseMessage) == 0 { 903 t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) 904 } 905 if status := binary.BigEndian.Uint16(nb[0][2:4]); status != wsCloseStatusInvalidPayloadData { 906 t.Fatalf("Expected status to be %v, got %v", wsCloseStatusInvalidPayloadData, status) 907 } 908 if !bytes.Contains(nb[0][4:], []byte("utf8")) { 909 t.Fatalf("Unexpected content: %s", nb[0][4:]) 910 } 911 } 912 913 func TestWSReadErrors(t *testing.T) { 914 for _, test := range []struct { 915 cframe func() []byte 916 err string 917 nbufs int 918 }{ 919 { 920 func() []byte { 921 msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello")) 922 msg[1] &= ^byte(wsMaskBit) 923 return msg 924 }, 925 "mask bit missing", 1, 926 }, 927 { 928 func() []byte { 929 return testWSCreateClientMsg(wsPingMessage, 1, true, false, make([]byte, 200)) 930 }, 931 "control frame length bigger than maximum allowed", 1, 932 }, 933 { 934 func() []byte { 935 return testWSCreateClientMsg(wsPingMessage, 1, false, false, []byte("hello")) 936 }, 937 "control frame does not have final bit set", 1, 938 }, 939 { 940 func() []byte { 941 frag1 := testWSCreateClientMsg(wsBinaryMessage, 1, false, false, []byte("frag1")) 942 newMsg := testWSCreateClientMsg(wsBinaryMessage, 
1, true, false, []byte("new message")) 943 all := append([]byte(nil), frag1...) 944 all = append(all, newMsg...) 945 return all 946 }, 947 "new message started before final frame for previous message was received", 2, 948 }, 949 { 950 func() []byte { 951 frame := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("frame")) 952 frag := testWSCreateClientMsg(wsBinaryMessage, 2, false, false, []byte("continuation")) 953 all := append([]byte(nil), frame...) 954 all = append(all, frag...) 955 return all 956 }, 957 "invalid continuation frame", 2, 958 }, 959 { 960 func() []byte { 961 return testWSCreateClientMsg(wsBinaryMessage, 2, false, true, []byte("frame")) 962 }, 963 "invalid continuation frame", 1, 964 }, 965 { 966 func() []byte { 967 return testWSCreateClientMsg(99, 1, false, false, []byte("hello")) 968 }, 969 "unknown opcode", 1, 970 }, 971 } { 972 t.Run(test.err, func(t *testing.T) { 973 c, ri, tr := testWSSetupForRead() 974 // Add a valid message first 975 msg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("hello")) 976 // Then add the bad frame 977 bad := test.cframe() 978 // Add them both to a read buffer 979 rb := append([]byte(nil), msg...) 980 rb = append(rb, bad...) 
981 bufs, err := c.wsRead(ri, tr, rb) 982 if err == nil || !strings.Contains(err.Error(), test.err) { 983 t.Fatalf("Expected error to contain %q, got %q", test.err, err.Error()) 984 } 985 if n := len(bufs); n != test.nbufs { 986 t.Fatalf("Unexpected number of buffers: %v", n) 987 } 988 if string(bufs[0]) != "hello" { 989 t.Fatalf("Unexpected content: %s", bufs[0]) 990 } 991 }) 992 } 993 } 994 995 func TestWSEnqueueCloseMsg(t *testing.T) { 996 for _, test := range []struct { 997 reason ClosedState 998 status int 999 }{ 1000 {ClientClosed, wsCloseStatusNormalClosure}, 1001 {AuthenticationTimeout, wsCloseStatusPolicyViolation}, 1002 {AuthenticationViolation, wsCloseStatusPolicyViolation}, 1003 {SlowConsumerPendingBytes, wsCloseStatusPolicyViolation}, 1004 {SlowConsumerWriteDeadline, wsCloseStatusPolicyViolation}, 1005 {MaxAccountConnectionsExceeded, wsCloseStatusPolicyViolation}, 1006 {MaxConnectionsExceeded, wsCloseStatusPolicyViolation}, 1007 {MaxControlLineExceeded, wsCloseStatusPolicyViolation}, 1008 {MaxSubscriptionsExceeded, wsCloseStatusPolicyViolation}, 1009 {MissingAccount, wsCloseStatusPolicyViolation}, 1010 {AuthenticationExpired, wsCloseStatusPolicyViolation}, 1011 {Revocation, wsCloseStatusPolicyViolation}, 1012 {TLSHandshakeError, wsCloseStatusTLSHandshake}, 1013 {ParseError, wsCloseStatusProtocolError}, 1014 {ProtocolViolation, wsCloseStatusProtocolError}, 1015 {BadClientProtocolVersion, wsCloseStatusProtocolError}, 1016 {MaxPayloadExceeded, wsCloseStatusMessageTooBig}, 1017 {ServerShutdown, wsCloseStatusGoingAway}, 1018 {WriteError, wsCloseStatusAbnormalClosure}, 1019 {ReadError, wsCloseStatusAbnormalClosure}, 1020 {StaleConnection, wsCloseStatusAbnormalClosure}, 1021 {ClosedState(254), wsCloseStatusInternalSrvError}, 1022 } { 1023 t.Run(test.reason.String(), func(t *testing.T) { 1024 c, _, _ := testWSSetupForRead() 1025 c.wsEnqueueCloseMessage(test.reason) 1026 c.mu.Lock() 1027 nb, _ := c.collapsePtoNB() 1028 c.mu.Unlock() 1029 if n := len(nb); n != 1 
{ 1030 t.Fatalf("Expected 1 buffer, got %v", n) 1031 } 1032 b := nb[0][0] 1033 if b&wsFinalBit == 0 { 1034 t.Fatalf("Control frame should have been the final flag, it was not set: %v", b) 1035 } 1036 if b&byte(wsCloseMessage) == 0 { 1037 t.Fatalf("Should have been a CLOSE, it wasn't: %v", b) 1038 } 1039 if status := binary.BigEndian.Uint16(nb[0][2:4]); int(status) != test.status { 1040 t.Fatalf("Expected status to be %v, got %v", test.status, status) 1041 } 1042 if string(nb[0][4:]) != test.reason.String() { 1043 t.Fatalf("Unexpected content: %s", nb[0][4:]) 1044 } 1045 }) 1046 } 1047 } 1048 1049 type testResponseWriter struct { 1050 http.ResponseWriter 1051 buf bytes.Buffer 1052 headers http.Header 1053 err error 1054 brw *bufio.ReadWriter 1055 conn *testWSFakeNetConn 1056 } 1057 1058 func (trw *testResponseWriter) Write(p []byte) (int, error) { 1059 return trw.buf.Write(p) 1060 } 1061 1062 func (trw *testResponseWriter) WriteHeader(status int) { 1063 trw.buf.WriteString(fmt.Sprintf("%v", status)) 1064 } 1065 1066 func (trw *testResponseWriter) Header() http.Header { 1067 if trw.headers == nil { 1068 trw.headers = make(http.Header) 1069 } 1070 return trw.headers 1071 } 1072 1073 type testWSFakeNetConn struct { 1074 net.Conn 1075 wbuf bytes.Buffer 1076 err error 1077 wsOpened bool 1078 isClosed bool 1079 deadlineCleared bool 1080 } 1081 1082 func (c *testWSFakeNetConn) Write(p []byte) (int, error) { 1083 if c.err != nil { 1084 return 0, c.err 1085 } 1086 return c.wbuf.Write(p) 1087 } 1088 1089 func (c *testWSFakeNetConn) SetDeadline(t time.Time) error { 1090 if t.IsZero() { 1091 c.deadlineCleared = true 1092 } 1093 return nil 1094 } 1095 1096 func (c *testWSFakeNetConn) Close() error { 1097 c.isClosed = true 1098 return nil 1099 } 1100 1101 func (trw *testResponseWriter) Hijack() (net.Conn, *bufio.ReadWriter, error) { 1102 if trw.conn == nil { 1103 trw.conn = &testWSFakeNetConn{} 1104 } 1105 trw.conn.wsOpened = true 1106 if trw.brw == nil { 1107 trw.brw = 
bufio.NewReadWriter(bufio.NewReader(trw.conn), bufio.NewWriter(trw.conn)) 1108 } 1109 return trw.conn, trw.brw, trw.err 1110 } 1111 1112 func testWSOptions() *Options { 1113 opts := DefaultOptions() 1114 opts.DisableShortFirstPing = true 1115 opts.Websocket.Host = "127.0.0.1" 1116 opts.Websocket.Port = -1 1117 opts.NoSystemAccount = true 1118 var err error 1119 tc := &TLSConfigOpts{ 1120 CertFile: "./configs/certs/server.pem", 1121 KeyFile: "./configs/certs/key.pem", 1122 } 1123 opts.Websocket.TLSConfig, err = GenTLSConfig(tc) 1124 if err != nil { 1125 panic(err) 1126 } 1127 return opts 1128 } 1129 1130 func testWSCreateValidReq() *http.Request { 1131 req := &http.Request{ 1132 Method: "GET", 1133 Host: "localhost", 1134 Proto: "HTTP/1.1", 1135 } 1136 req.Header = make(http.Header) 1137 req.Header.Set("Upgrade", "websocket") 1138 req.Header.Set("Connection", "Upgrade") 1139 req.Header.Set("Sec-Websocket-Key", "dGhlIHNhbXBsZSBub25jZQ==") 1140 req.Header.Set("Sec-Websocket-Version", "13") 1141 return req 1142 } 1143 1144 func TestWSCheckOrigin(t *testing.T) { 1145 notSameOrigin := false 1146 sameOrigin := true 1147 allowedListEmpty := []string{} 1148 someList := []string{"http://host1.com", "http://host2.com:1234"} 1149 1150 for _, test := range []struct { 1151 name string 1152 sameOrigin bool 1153 origins []string 1154 reqHost string 1155 reqTLS bool 1156 origin string 1157 err string 1158 }{ 1159 {"any", notSameOrigin, allowedListEmpty, "", false, "http://any.host.com", ""}, 1160 {"same origin ok", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:80", ""}, 1161 {"same origin bad host", sameOrigin, allowedListEmpty, "host.com", false, "http://other.host.com", "not same origin"}, 1162 {"same origin bad port", sameOrigin, allowedListEmpty, "host.com", false, "http://host.com:81", "not same origin"}, 1163 {"same origin bad scheme", sameOrigin, allowedListEmpty, "host.com", true, "http://host.com", "not same origin"}, 1164 {"same origin bad uri", 
sameOrigin, allowedListEmpty, "host.com", false, "@@@://invalid:url:1234", "invalid URI"}, 1165 {"same origin bad url", sameOrigin, allowedListEmpty, "host.com", false, "http://invalid:url:1234", "too many colons"}, 1166 {"same origin bad req host", sameOrigin, allowedListEmpty, "invalid:url:1234", false, "http://host.com", "too many colons"}, 1167 {"no origin same origin ignored", sameOrigin, allowedListEmpty, "", false, "", ""}, 1168 {"no origin list ignored", sameOrigin, someList, "", false, "", ""}, 1169 {"no origin same origin and list ignored", sameOrigin, someList, "", false, "", ""}, 1170 {"allowed from list", notSameOrigin, someList, "", false, "http://host2.com:1234", ""}, 1171 {"allowed with different path", notSameOrigin, someList, "", false, "http://host1.com/some/path", ""}, 1172 {"list bad port", notSameOrigin, someList, "", false, "http://host1.com:1234", "not in the allowed list"}, 1173 {"list bad scheme", notSameOrigin, someList, "", false, "https://host2.com:1234", "not in the allowed list"}, 1174 } { 1175 t.Run(test.name, func(t *testing.T) { 1176 opts := DefaultOptions() 1177 opts.Websocket.SameOrigin = test.sameOrigin 1178 opts.Websocket.AllowedOrigins = test.origins 1179 s := &Server{opts: opts} 1180 s.wsSetOriginOptions(&opts.Websocket) 1181 1182 req := testWSCreateValidReq() 1183 req.Host = test.reqHost 1184 if test.reqTLS { 1185 req.TLS = &tls.ConnectionState{} 1186 } 1187 if test.origin != "" { 1188 req.Header.Set("Origin", test.origin) 1189 } 1190 err := s.websocket.checkOrigin(req) 1191 if test.err == "" && err != nil { 1192 t.Fatalf("Unexpected error: %v", err) 1193 } else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) { 1194 t.Fatalf("Expected error %q, got %v", test.err, err) 1195 } 1196 }) 1197 } 1198 } 1199 1200 func TestWSUpgradeValidationErrors(t *testing.T) { 1201 for _, test := range []struct { 1202 name string 1203 setup func() (*Options, *testResponseWriter, *http.Request) 1204 err string 1205 
status int
	}{
		{
			name: "bad method",
			setup: func() (*Options, *testResponseWriter, *http.Request) {
				req := testWSCreateValidReq()
				req.Method = "POST"
				return testWSOptions(), nil, req
			},
			err:    "must be GET",
			status: http.StatusMethodNotAllowed,
		},
		{
			name: "no host",
			setup: func() (*Options, *testResponseWriter, *http.Request) {
				req := testWSCreateValidReq()
				req.Host = ""
				return testWSOptions(), nil, req
			},
			err:    "'Host' missing in request",
			status: http.StatusBadRequest,
		},
		{
			name: "invalid upgrade header",
			setup: func() (*Options, *testResponseWriter, *http.Request) {
				req := testWSCreateValidReq()
				req.Header.Del("Upgrade")
				return testWSOptions(), nil, req
			},
			err:    "invalid value for header 'Upgrade'",
			status: http.StatusBadRequest,
		},
		{
			name: "invalid connection header",
			setup: func() (*Options, *testResponseWriter, *http.Request) {
				req := testWSCreateValidReq()
				req.Header.Del("Connection")
				return testWSOptions(), nil, req
			},
			err:    "invalid value for header 'Connection'",
			status: http.StatusBadRequest,
		},
		{
			name: "no key",
			setup: func() (*Options, *testResponseWriter, *http.Request) {
				req := testWSCreateValidReq()
				req.Header.Del("Sec-Websocket-Key")
				return testWSOptions(), nil, req
			},
			err:    "key missing",
			status: http.StatusBadRequest,
		},
		{
			name: "empty key",
			setup: func() (*Options, *testResponseWriter, *http.Request) {
				req := testWSCreateValidReq()
				req.Header.Set("Sec-Websocket-Key", "")
				return testWSOptions(), nil, req
			},
			err:    "key missing",
			status: http.StatusBadRequest,
		},
		{
			name: "missing version",
			setup: func() (*Options, *testResponseWriter, *http.Request) {
				req := testWSCreateValidReq()
				req.Header.Del("Sec-Websocket-Version")
				return testWSOptions(), nil, req
			},
			err:    "invalid version",
			status: http.StatusBadRequest,
		},
		{
			name: "wrong version",
			setup: func() (*Options, *testResponseWriter, *http.Request) {
				req := testWSCreateValidReq()
				req.Header.Set("Sec-Websocket-Version", "99")
				return testWSOptions(), nil, req
			},
			err:    "invalid version",
			status: http.StatusBadRequest,
		},
		{
			name: "origin",
			setup: func() (*Options, *testResponseWriter, *http.Request) {
				opts := testWSOptions()
				opts.Websocket.SameOrigin = true
				req := testWSCreateValidReq()
				req.Header.Set("Origin", "http://bad.host.com")
				return opts, nil, req
			},
			err:    "origin not allowed",
			status: http.StatusForbidden,
		},
		{
			name: "hijack error",
			setup: func() (*Options, *testResponseWriter, *http.Request) {
				rw := &testResponseWriter{err: errors.New("on purpose")}
				return testWSOptions(), rw, testWSCreateValidReq()
			},
			err:    "on purpose",
			status: http.StatusInternalServerError,
		},
		{
			name: "hijack buffered data",
			setup: func() (*Options, *testResponseWriter, *http.Request) {
				rw := &testResponseWriter{
					conn: &testWSFakeNetConn{},
					brw:  bufio.NewReadWriter(bufio.NewReader(strings.NewReader("some data")), bufio.NewWriter(nil)),
				}
				// Consume a single byte so that the bufio reader fills
				// itself and is left with buffered, unread data.
				var first [1]byte
				if _, err := io.ReadAtLeast(rw.brw, first[:], 1); err != nil {
					// The setup runs inside the subtest, so fail loudly the
					// same way testWSOptions does rather than through t.
					panic(err)
				}
				return testWSOptions(), rw, testWSCreateValidReq()
			},
			err:    "client sent data before handshake is complete",
			status: http.StatusBadRequest,
		},
	} {
		t.Run(test.name, func(t *testing.T) {
			opts, rw, req := test.setup()
			if rw == nil {
				rw = &testResponseWriter{}
			}
			s := &Server{opts: opts}
			s.wsSetOriginOptions(&opts.Websocket)
			res, err := s.wsUpgrade(rw, req)
			if err == nil || !strings.Contains(err.Error(), test.err) {
				t.Fatalf("Should get error %q, got %v", test.err, err)
			}
			if res != nil {
				t.Fatalf("Should not have returned a result, got %v", res)
			}
			// testResponseWriter records the status code followed by the
			// body, hence the concatenated form that is expected here.
			expected := fmt.Sprintf("%v%s\n", test.status, http.StatusText(test.status))
			if got := rw.buf.String(); got != expected {
				t.Fatalf("Expected %q got %q", expected, got)
			}
			// Check that if the connection was opened, it is now closed.
			if conn := rw.conn; conn != nil && conn.wsOpened && !conn.isClosed {
				t.Fatal("Connection was opened, but has not been closed")
			}
		})
	}
}

// TestWSUpgradeResponseWriteError makes the hijacked connection fail on
// write and checks that wsUpgrade reports that error, returns no result
// and closes the connection.
func TestWSUpgradeResponseWriteError(t *testing.T) {
	expectedErr := errors.New("on purpose")
	rw := &testResponseWriter{
		conn: &testWSFakeNetConn{err: expectedErr},
	}
	s := &Server{opts: testWSOptions()}
	res, err := s.wsUpgrade(rw, testWSCreateValidReq())
	if !errors.Is(err, expectedErr) {
		t.Fatalf("Should get error %q, got %v", expectedErr.Error(), err)
	}
	if res != nil {
		t.Fatalf("Should not have returned a result, got %v", res)
	}
	if !rw.conn.isClosed {
		t.Fatal("Connection should have been closed")
	}
}

// TestWSUpgradeConnDeadline checks that, with a handshake timeout set, a
// successful upgrade leaves the connection open and with its deadline
// cleared.
func TestWSUpgradeConnDeadline(t *testing.T) {
	opts := testWSOptions()
	opts.Websocket.HandshakeTimeout = time.Second
	s := &Server{opts: opts}
	rw := &testResponseWriter{}
	res, err := s.wsUpgrade(rw, testWSCreateValidReq())
	if res == nil || err != nil {
		t.Fatalf("Unexpected error: %v", err)
	}
	if rw.conn.isClosed {
		t.Fatal("Connection should NOT have been closed")
	}
	if !rw.conn.deadlineCleared {
		t.Fatal("Connection deadline should have been cleared after handshake")
	}
}

// TestWSCompressNegotiation checks when the "permessage-deflate"
// extension is included in the handshake response.
func TestWSCompressNegotiation(t *testing.T) {
	// No compression on the server, but client asks
	opts := testWSOptions()
	s := &Server{opts: opts}
	rw := &testResponseWriter{}
	req := testWSCreateValidReq()
	req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
	res, err := s.wsUpgrade(rw, req)
if res == nil || err != nil { 1410 t.Fatalf("Unexpected error: %v", err) 1411 } 1412 // The http response should not contain "permessage-deflate" 1413 output := rw.conn.wbuf.String() 1414 if strings.Contains(output, "permessage-deflate") { 1415 t.Fatalf("Compression disabled in server so response to client should not contain extension, got %s", output) 1416 } 1417 1418 // Option in the server and client, so compression should be negotiated. 1419 s.opts.Websocket.Compression = true 1420 rw = &testResponseWriter{} 1421 res, err = s.wsUpgrade(rw, req) 1422 if res == nil || err != nil { 1423 t.Fatalf("Unexpected error: %v", err) 1424 } 1425 // The http response should not contain "permessage-deflate" 1426 output = rw.conn.wbuf.String() 1427 if !strings.Contains(output, "permessage-deflate") { 1428 t.Fatalf("Compression in server and client request, so response should contain extension, got %s", output) 1429 } 1430 1431 // Option in server but not asked by the client, so response should not contain "permessage-deflate" 1432 rw = &testResponseWriter{} 1433 req.Header.Del("Sec-Websocket-Extensions") 1434 res, err = s.wsUpgrade(rw, req) 1435 if res == nil || err != nil { 1436 t.Fatalf("Unexpected error: %v", err) 1437 } 1438 // The http response should not contain "permessage-deflate" 1439 output = rw.conn.wbuf.String() 1440 if strings.Contains(output, "permessage-deflate") { 1441 t.Fatalf("Compression in server but not in client, so response to client should not contain extension, got %s", output) 1442 } 1443 } 1444 1445 func TestWSSetHeader(t *testing.T) { 1446 opts := testWSOptions() 1447 opts.Websocket.Headers = map[string]string{ 1448 "X-Header": "some-value", 1449 "X-Another-Header": "another-value", 1450 } 1451 s := &Server{opts: opts} 1452 s.wsSetHeadersOptions(&opts.Websocket) 1453 rw := &testResponseWriter{} 1454 req := testWSCreateValidReq() 1455 res, err := s.wsUpgrade(rw, req) 1456 if res == nil || err != nil { 1457 t.Fatalf("Unexpected error: %v", err) 1458 } 
1459 1460 buf := bufio.NewReader(&rw.conn.wbuf) 1461 resp, err := http.ReadResponse(buf, req) 1462 if err != nil { 1463 t.Fatalf("Error reading request: %v", err) 1464 } 1465 defer resp.Body.Close() 1466 1467 // Check that the response is a 101 1468 if resp.StatusCode != http.StatusSwitchingProtocols { 1469 t.Fatalf("Expected 101, got %v", resp.StatusCode) 1470 } 1471 1472 headers := resp.Header.Clone() 1473 1474 // Compare all the headers 1475 for k, v := range opts.Websocket.Headers { 1476 if got := headers.Get(k); got != v { 1477 t.Fatalf("Expected %q for header %q, got %q", v, k, got) 1478 } 1479 headers.Del(k) 1480 } 1481 1482 // Check remain headers 1483 for k, v := range map[string]string{ 1484 "Upgrade": "websocket", 1485 "Connection": "Upgrade", 1486 "Sec-Websocket-Accept": wsAcceptKey(req.Header.Get("Sec-Websocket-Key")), 1487 } { 1488 if got := headers.Get(k); got != v { 1489 t.Fatalf("Expected %q for header %q, got %q", v, k, got) 1490 } 1491 headers.Del(k) 1492 } 1493 1494 // Check that we have no more headers 1495 if len(headers) > 0 { 1496 t.Fatalf("Unexpected headers: %v", headers) 1497 } 1498 } 1499 1500 func TestWSParseOptions(t *testing.T) { 1501 for _, test := range []struct { 1502 name string 1503 content string 1504 checkOpt func(*WebsocketOpts) error 1505 err string 1506 }{ 1507 // Negative tests 1508 {"bad type", "websocket: []", nil, "to be a map"}, 1509 {"bad listen", "websocket: { listen: [] }", nil, "port or host:port"}, 1510 {"bad port", `websocket: { port: "abc" }`, nil, "not int64"}, 1511 {"bad host", `websocket: { host: 123 }`, nil, "not string"}, 1512 {"bad advertise type", `websocket: { advertise: 123 }`, nil, "not string"}, 1513 {"bad tls", `websocket: { tls: 123 }`, nil, "not map[string]interface {}"}, 1514 {"bad same origin", `websocket: { same_origin: "abc" }`, nil, "not bool"}, 1515 {"bad allowed origins type", `websocket: { allowed_origins: {} }`, nil, "unsupported type"}, 1516 {"bad allowed origins values", `websocket: { 
allowed_origins: [ {} ] }`, nil, "unsupported type in array"}, 1517 {"bad handshake timeout type", `websocket: { handshake_timeout: [] }`, nil, "unsupported type"}, 1518 {"bad handshake timeout duration", `websocket: { handshake_timeout: "abc" }`, nil, "invalid duration"}, 1519 {"bad header type", `websocket: { headers: 123 }`, nil, "unsupported type"}, 1520 {"bad header type", `websocket: { headers: [] }`, nil, "unsupported type"}, 1521 {"bad header value", `websocket: { headers: { "key": 123 } }`, nil, "unsupported type"}, 1522 {"unknown field", `websocket: { this_does_not_exist: 123 }`, nil, "unknown"}, 1523 // Positive tests 1524 {"listen port only", `websocket { listen: 1234 }`, func(wo *WebsocketOpts) error { 1525 if wo.Port != 1234 { 1526 return fmt.Errorf("expected 1234, got %v", wo.Port) 1527 } 1528 return nil 1529 }, ""}, 1530 {"listen host and port", `websocket { listen: "localhost:1234" }`, func(wo *WebsocketOpts) error { 1531 if wo.Host != "localhost" || wo.Port != 1234 { 1532 return fmt.Errorf("expected localhost:1234, got %v:%v", wo.Host, wo.Port) 1533 } 1534 return nil 1535 }, ""}, 1536 {"host", `websocket { host: "localhost" }`, func(wo *WebsocketOpts) error { 1537 if wo.Host != "localhost" { 1538 return fmt.Errorf("expected localhost, got %v", wo.Host) 1539 } 1540 return nil 1541 }, ""}, 1542 {"port", `websocket { port: 1234 }`, func(wo *WebsocketOpts) error { 1543 if wo.Port != 1234 { 1544 return fmt.Errorf("expected 1234, got %v", wo.Port) 1545 } 1546 return nil 1547 }, ""}, 1548 {"advertise", `websocket { advertise: "host:1234" }`, func(wo *WebsocketOpts) error { 1549 if wo.Advertise != "host:1234" { 1550 return fmt.Errorf("expected %q, got %q", "host:1234", wo.Advertise) 1551 } 1552 return nil 1553 }, ""}, 1554 {"same origin", `websocket { same_origin: true }`, func(wo *WebsocketOpts) error { 1555 if !wo.SameOrigin { 1556 return fmt.Errorf("expected same_origin==true, got %v", wo.SameOrigin) 1557 } 1558 return nil 1559 }, ""}, 1560 {"allowed 
origins one only", `websocket { allowed_origins: "https://host.com/" }`, func(wo *WebsocketOpts) error { 1561 expected := []string{"https://host.com/"} 1562 if !reflect.DeepEqual(wo.AllowedOrigins, expected) { 1563 return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins) 1564 } 1565 return nil 1566 }, ""}, 1567 {"allowed origins array", 1568 ` 1569 websocket { 1570 allowed_origins: [ 1571 "https://host1.com/" 1572 "https://host2.com/" 1573 ] 1574 } 1575 `, func(wo *WebsocketOpts) error { 1576 expected := []string{"https://host1.com/", "https://host2.com/"} 1577 if !reflect.DeepEqual(wo.AllowedOrigins, expected) { 1578 return fmt.Errorf("expected allowed origins to be %q, got %q", expected, wo.AllowedOrigins) 1579 } 1580 return nil 1581 }, ""}, 1582 {"handshake timeout in whole seconds", `websocket { handshake_timeout: 3 }`, func(wo *WebsocketOpts) error { 1583 if wo.HandshakeTimeout != 3*time.Second { 1584 return fmt.Errorf("expected handshake to be 3s, got %v", wo.HandshakeTimeout) 1585 } 1586 return nil 1587 }, ""}, 1588 {"handshake timeout n duration", `websocket { handshake_timeout: "4s" }`, func(wo *WebsocketOpts) error { 1589 if wo.HandshakeTimeout != 4*time.Second { 1590 return fmt.Errorf("expected handshake to be 4s, got %v", wo.HandshakeTimeout) 1591 } 1592 return nil 1593 }, ""}, 1594 {"tls config", 1595 ` 1596 websocket { 1597 tls { 1598 cert_file: "./configs/certs/server.pem" 1599 key_file: "./configs/certs/key.pem" 1600 } 1601 } 1602 `, func(wo *WebsocketOpts) error { 1603 if wo.TLSConfig == nil { 1604 return fmt.Errorf("TLSConfig should have been set") 1605 } 1606 return nil 1607 }, ""}, 1608 {"compression", 1609 ` 1610 websocket { 1611 compression: true 1612 } 1613 `, func(wo *WebsocketOpts) error { 1614 if !wo.Compression { 1615 return fmt.Errorf("Compression should have been set") 1616 } 1617 return nil 1618 }, ""}, 1619 {"jwt cookie", 1620 ` 1621 websocket { 1622 jwt_cookie: "jwtcookie" 1623 } 1624 `, func(wo 
*WebsocketOpts) error { 1625 if wo.JWTCookie != "jwtcookie" { 1626 return fmt.Errorf("Invalid JWTCookie value: %q", wo.JWTCookie) 1627 } 1628 return nil 1629 }, ""}, 1630 {"no auth user", 1631 ` 1632 websocket { 1633 no_auth_user: "noauthuser" 1634 } 1635 `, func(wo *WebsocketOpts) error { 1636 if wo.NoAuthUser != "noauthuser" { 1637 return fmt.Errorf("Invalid NoAuthUser value: %q", wo.NoAuthUser) 1638 } 1639 return nil 1640 }, ""}, 1641 {"auth block", 1642 ` 1643 websocket { 1644 authorization { 1645 user: "webuser" 1646 password: "pwd" 1647 token: "token" 1648 timeout: 2.0 1649 } 1650 } 1651 `, func(wo *WebsocketOpts) error { 1652 if wo.Username != "webuser" || wo.Password != "pwd" || wo.Token != "token" || wo.AuthTimeout != 2.0 { 1653 return fmt.Errorf("Invalid auth block: %+v", wo) 1654 } 1655 return nil 1656 }, ""}, 1657 {"auth timeout as int", 1658 ` 1659 websocket { 1660 authorization { 1661 timeout: 2 1662 } 1663 } 1664 `, func(wo *WebsocketOpts) error { 1665 if wo.AuthTimeout != 2.0 { 1666 return fmt.Errorf("Invalid auth timeout: %v", wo.AuthTimeout) 1667 } 1668 return nil 1669 }, ""}, 1670 {"headers block", 1671 ` 1672 websocket { 1673 headers { 1674 "X-Header": "some-value" 1675 "X-Another-Header": "another-value" 1676 } 1677 } 1678 `, func(wo *WebsocketOpts) error { 1679 if len(wo.Headers) != 2 { 1680 return fmt.Errorf("Expected 2 headers, got %v", len(wo.Headers)) 1681 } 1682 1683 for k, v := range map[string]string{ 1684 "X-Header": "some-value", 1685 "X-Another-Header": "another-value", 1686 } { 1687 if got, ok := wo.Headers[k]; !ok || got != v { 1688 return fmt.Errorf("Invalid value for %q: %q", k, got) 1689 } 1690 } 1691 return nil 1692 }, ""}, 1693 } { 1694 t.Run(test.name, func(t *testing.T) { 1695 conf := createConfFile(t, []byte(test.content)) 1696 o, err := ProcessConfigFile(conf) 1697 if test.err != _EMPTY_ { 1698 if err == nil || !strings.Contains(err.Error(), test.err) { 1699 t.Fatalf("For content: %q, expected error about %q, got %v", 
test.content, test.err, err) 1700 } 1701 return 1702 } else if err != nil { 1703 t.Fatalf("Unexpected error for content %q: %v", test.content, err) 1704 } 1705 if err := test.checkOpt(&o.Websocket); err != nil { 1706 t.Fatalf("Incorrect option for content %q: %v", test.content, err.Error()) 1707 } 1708 }) 1709 } 1710 } 1711 1712 func TestWSValidateOptions(t *testing.T) { 1713 nwso := DefaultOptions() 1714 wso := testWSOptions() 1715 for _, test := range []struct { 1716 name string 1717 getOpts func() *Options 1718 err string 1719 }{ 1720 {"websocket disabled", func() *Options { return nwso.Clone() }, ""}, 1721 {"no tls", func() *Options { o := wso.Clone(); o.Websocket.TLSConfig = nil; return o }, "requires TLS configuration"}, 1722 {"bad url in allowed list", func() *Options { 1723 o := wso.Clone() 1724 o.Websocket.AllowedOrigins = []string{"http://this:is:bad:url"} 1725 return o 1726 }, "unable to parse"}, 1727 {"missing trusted configuration", func() *Options { 1728 o := wso.Clone() 1729 o.Websocket.JWTCookie = "jwt" 1730 return o 1731 }, "keys configuration is required"}, 1732 {"websocket username not allowed if users specified", func() *Options { 1733 o := wso.Clone() 1734 o.Nkeys = []*NkeyUser{{Nkey: "abc"}} 1735 o.Websocket.Username = "b" 1736 o.Websocket.Password = "pwd" 1737 return o 1738 }, "websocket authentication username not compatible with presence of users/nkeys"}, 1739 {"websocket token not allowed if users specified", func() *Options { 1740 o := wso.Clone() 1741 o.Nkeys = []*NkeyUser{{Nkey: "abc"}} 1742 o.Websocket.Token = "mytoken" 1743 return o 1744 }, "websocket authentication token not compatible with presence of users/nkeys"}, 1745 {"headers with sec-websocket- prefix not allowed", func() *Options { 1746 o := wso.Clone() 1747 o.Websocket.Headers = map[string]string{"Sec-WebSocket-Key": "123"} 1748 return o 1749 }, `invalid header "Sec-WebSocket-Key", "Sec-WebSocket-" prefix not allowed`}, 1750 {"header with host", func() *Options { 1751 o := 
wso.Clone() 1752 o.Websocket.Headers = map[string]string{"Host": "http://localhost:8080"} 1753 return o 1754 }, `websocket: invalid header "Host" not allowed`}, 1755 {"header with content-length", func() *Options { 1756 o := wso.Clone() 1757 o.Websocket.Headers = map[string]string{"Content-Length": "0"} 1758 return o 1759 }, `websocket: invalid header "Content-Length" not allowed`}, 1760 {"header with connection", func() *Options { 1761 o := wso.Clone() 1762 o.Websocket.Headers = map[string]string{"Connection": "Upgrade"} 1763 return o 1764 }, `websocket: invalid header "Connection" not allowed`}, 1765 {"header with upgrade", func() *Options { 1766 o := wso.Clone() 1767 o.Websocket.Headers = map[string]string{"Upgrade": "websocket"} 1768 return o 1769 }, `websocket: invalid header "Upgrade" not allowed`}, 1770 {"header with Nats-No-Masking", func() *Options { 1771 o := wso.Clone() 1772 o.Websocket.Headers = map[string]string{"Nats-No-Masking": "false"} 1773 return o 1774 }, `websocket: invalid header "Nats-No-Masking" not allowed`}, 1775 } { 1776 t.Run(test.name, func(t *testing.T) { 1777 err := validateWebsocketOptions(test.getOpts()) 1778 if test.err == "" && err != nil { 1779 t.Fatalf("Unexpected error: %v", err) 1780 } else if test.err != "" && (err == nil || !strings.Contains(err.Error(), test.err)) { 1781 t.Fatalf("Expected error to contain %q, got %v", test.err, err) 1782 } 1783 }) 1784 } 1785 } 1786 1787 func TestWSSetOriginOptions(t *testing.T) { 1788 o := testWSOptions() 1789 for _, test := range []struct { 1790 content string 1791 err string 1792 }{ 1793 {"@@@://host.com/", "invalid URI"}, 1794 {"http://this:is:bad:url/", "invalid port"}, 1795 } { 1796 t.Run(test.err, func(t *testing.T) { 1797 o.Websocket.AllowedOrigins = []string{test.content} 1798 s := &Server{} 1799 l := &captureErrorLogger{errCh: make(chan string, 1)} 1800 s.SetLogger(l, false, false) 1801 s.wsSetOriginOptions(&o.Websocket) 1802 select { 1803 case e := <-l.errCh: 1804 if 
!strings.Contains(e, test.err) { 1805 t.Fatalf("Unexpected error: %v", e) 1806 } 1807 case <-time.After(50 * time.Millisecond): 1808 t.Fatalf("Did not get the error") 1809 } 1810 1811 }) 1812 } 1813 } 1814 1815 type captureFatalLogger struct { 1816 DummyLogger 1817 fatalCh chan string 1818 } 1819 1820 func (l *captureFatalLogger) Fatalf(format string, v ...any) { 1821 select { 1822 case l.fatalCh <- fmt.Sprintf(format, v...): 1823 default: 1824 } 1825 } 1826 1827 func TestWSFailureToStartServer(t *testing.T) { 1828 // Create a listener to use a port 1829 l, err := net.Listen("tcp", "127.0.0.1:0") 1830 if err != nil { 1831 t.Fatalf("Error listening: %v", err) 1832 } 1833 defer l.Close() 1834 1835 o := testWSOptions() 1836 // Make sure we don't have unnecessary listen ports opened. 1837 o.HTTPPort = 0 1838 o.Cluster.Port = 0 1839 o.Gateway.Name = "" 1840 o.Gateway.Port = 0 1841 o.LeafNode.Port = 0 1842 o.Websocket.Port = l.Addr().(*net.TCPAddr).Port 1843 s, err := NewServer(o) 1844 if err != nil { 1845 t.Fatalf("Error creating server: %v", err) 1846 } 1847 defer s.Shutdown() 1848 logger := &captureFatalLogger{fatalCh: make(chan string, 1)} 1849 s.SetLogger(logger, false, false) 1850 1851 wg := sync.WaitGroup{} 1852 wg.Add(1) 1853 go func() { 1854 s.Start() 1855 wg.Done() 1856 }() 1857 1858 select { 1859 case e := <-logger.fatalCh: 1860 if !strings.Contains(e, "Unable to listen") { 1861 t.Fatalf("Unexpected error: %v", e) 1862 } 1863 case <-time.After(2 * time.Second): 1864 t.Fatalf("Should have reported a fatal error") 1865 } 1866 // Since this is a test and the process does not actually 1867 // exit on Fatal error, wait for the client port to be 1868 // ready so when we shutdown we don't leave the accept 1869 // loop hanging. 
1870 checkFor(t, time.Second, 15*time.Millisecond, func() error { 1871 s.mu.Lock() 1872 ready := s.listener != nil 1873 s.mu.Unlock() 1874 if !ready { 1875 return fmt.Errorf("client accept loop not started yet") 1876 } 1877 return nil 1878 }) 1879 s.Shutdown() 1880 wg.Wait() 1881 } 1882 1883 func TestWSAbnormalFailureOfWebServer(t *testing.T) { 1884 o := testWSOptions() 1885 s := RunServer(o) 1886 defer s.Shutdown() 1887 logger := &captureFatalLogger{fatalCh: make(chan string, 1)} 1888 s.SetLogger(logger, false, false) 1889 1890 // Now close the WS listener to cause a WebServer error 1891 s.mu.Lock() 1892 s.websocket.listener.Close() 1893 s.mu.Unlock() 1894 1895 select { 1896 case e := <-logger.fatalCh: 1897 if !strings.Contains(e, "websocket listener error") { 1898 t.Fatalf("Unexpected error: %v", e) 1899 } 1900 case <-time.After(2 * time.Second): 1901 t.Fatalf("Should have reported a fatal error") 1902 } 1903 } 1904 1905 type testWSClientOptions struct { 1906 compress, web bool 1907 host string 1908 port int 1909 extraHeaders map[string][]string 1910 noTLS bool 1911 path string 1912 extraResponseHeaders map[string]string 1913 } 1914 1915 func testNewWSClient(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte) { 1916 t.Helper() 1917 c, br, info, err := testNewWSClientWithError(t, o) 1918 if err != nil { 1919 t.Fatal(err) 1920 } 1921 return c, br, info 1922 } 1923 1924 func testNewWSClientWithError(t testing.TB, o testWSClientOptions) (net.Conn, *bufio.Reader, []byte, error) { 1925 addr := fmt.Sprintf("%s:%d", o.host, o.port) 1926 wsc, err := net.Dial("tcp", addr) 1927 if err != nil { 1928 return nil, nil, nil, fmt.Errorf("Error creating ws connection: %v", err) 1929 } 1930 if !o.noTLS { 1931 wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true}) 1932 wsc.SetDeadline(time.Now().Add(time.Second)) 1933 if err := wsc.(*tls.Conn).Handshake(); err != nil { 1934 return nil, nil, nil, fmt.Errorf("Error during handshake: %v", err) 1935 } 1936 
wsc.SetDeadline(time.Time{})
	}
	// fail releases the connection before reporting an error: the caller
	// only receives nil values on failure and cannot close it itself.
	fail := func(format string, args ...any) (net.Conn, *bufio.Reader, []byte, error) {
		wsc.Close()
		return nil, nil, nil, fmt.Errorf(format, args...)
	}
	req := testWSCreateValidReq()
	if o.compress {
		req.Header.Set("Sec-Websocket-Extensions", "permessage-deflate")
	}
	if o.web {
		req.Header.Set("User-Agent", "Mozilla/5.0")
	}
	for hdr, values := range o.extraHeaders {
		if len(values) == 0 {
			req.Header.Set(hdr, _EMPTY_)
			continue
		}
		// The first value replaces any existing one, the others are added.
		req.Header.Set(hdr, values[0])
		for _, v := range values[1:] {
			req.Header.Add(hdr, v)
		}
	}
	reqURL, err := url.Parse("wss://" + addr + o.path)
	if err != nil {
		return fail("Error parsing request URL: %v", err)
	}
	req.URL = reqURL
	if err := req.Write(wsc); err != nil {
		return fail("Error sending request: %v", err)
	}
	br := bufio.NewReader(wsc)
	resp, err := http.ReadResponse(br, req)
	if err != nil {
		return fail("Error reading response: %v", err)
	}
	defer resp.Body.Close()
	if resp.StatusCode != http.StatusSwitchingProtocols {
		return fail("Expected response status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode)
	}
	for k, v := range o.extraResponseHeaders {
		if value := resp.Header.Get(k); value != v {
			return fail("Expected header %q to be %q, got %q", k, v, value)
		}
	}
	var info []byte
	if o.path == mqttWSPath {
		// On the MQTT path only the negotiated sub-protocol is checked,
		// no INFO frame is read.
		if v := resp.Header[wsSecProto]; len(v) != 1 || v[0] != wsMQTTSecProtoVal {
			return fail("No mqtt protocol in header: %v", resp.Header)
		}
	} else {
		// Wait for the INFO
		info = testWSReadFrame(t, br)
		if !bytes.HasPrefix(info, []byte("INFO {")) {
			return fail("Expected INFO, got %s", info)
		}
	}
	return wsc, br, info, nil
}

// testClaimsOptions customizes the account/user claims and the CONNECT
// request used by testWSWithClaims, and the answer it expects.
type testClaimsOptions struct {
	nac            *jwt.AccountClaims
	nuc            *jwt.UserClaims
	connectRequest any
	dontSign       bool
	expectAnswer   string
}

// testWSWithClaims creates an account and a user, registers the account
// JWT with the server's memory resolver and connects a websocket client.
// If the server's INFO says authentication is required, a CONNECT (the
// custom one from tclm if set, otherwise one carrying the user JWT) is
// sent and the answer is compared with tclm.expectAnswer. It returns the
// account key pair, the connection, its reader and whether authentication
// was required.
func testWSWithClaims(t *testing.T, s *Server, o testWSClientOptions, tclm testClaimsOptions) (kp nkeys.KeyPair, conn net.Conn, rdr *bufio.Reader, authWasRequired bool) {
	t.Helper()

	okp, err := nkeys.FromSeed(oSeed)
	if err != nil {
		t.Fatalf("Error creating operator key pair: %v", err)
	}

	akp, err := nkeys.CreateAccount()
	if err != nil {
		t.Fatalf("Error creating account key pair: %v", err)
	}
	apub, err := akp.PublicKey()
	if err != nil {
		t.Fatalf("Error getting account public key: %v", err)
	}
	if tclm.nac == nil {
		tclm.nac = jwt.NewAccountClaims(apub)
	} else {
		tclm.nac.Subject = apub
	}
	ajwt, err := tclm.nac.Encode(okp)
	if err != nil {
		t.Fatalf("Error generating account JWT: %v", err)
	}

	nkp, err := nkeys.CreateUser()
	if err != nil {
		t.Fatalf("Error creating user key pair: %v", err)
	}
	pub, err := nkp.PublicKey()
	if err != nil {
		t.Fatalf("Error getting user public key: %v", err)
	}
	if tclm.nuc == nil {
		tclm.nuc = jwt.NewUserClaims(pub)
	} else {
		tclm.nuc.Subject = pub
	}
	// NOTE: from here on the local variable shadows the jwt package.
	jwt, err := tclm.nuc.Encode(akp)
	if err != nil {
		t.Fatalf("Error generating user JWT: %v", err)
	}

	addAccountToMemResolver(s, apub, ajwt)

	c, cr, l := testNewWSClient(t, o)

	var info struct {
		Nonce        string `json:"nonce,omitempty"`
		AuthRequired bool   `json:"auth_required,omitempty"`
	}

	// Skip the 5 leading bytes ("INFO ") to get to the JSON payload.
	if err := json.Unmarshal(l[5:], &info); err != nil {
		t.Fatal(err)
	}
	if info.AuthRequired {
		cs := ""
		if tclm.connectRequest != nil {
			customReq, err := json.Marshal(tclm.connectRequest)
			if err != nil {
				t.Fatal(err)
			}
			// PING needed to flush the +OK/-ERR to us.
2047 cs = fmt.Sprintf("CONNECT %v\r\nPING\r\n", string(customReq)) 2048 } else if !tclm.dontSign { 2049 // Sign Nonce 2050 sigraw, _ := nkp.Sign([]byte(info.Nonce)) 2051 sig := base64.RawURLEncoding.EncodeToString(sigraw) 2052 cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"sig\":\"%s\",\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt, sig) 2053 } else { 2054 cs = fmt.Sprintf("CONNECT {\"jwt\":%q,\"verbose\":true,\"pedantic\":true}\r\nPING\r\n", jwt) 2055 } 2056 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(cs)) 2057 c.Write(wsmsg) 2058 l = testWSReadFrame(t, cr) 2059 if !strings.HasPrefix(string(l), tclm.expectAnswer) { 2060 t.Fatalf("Expected %q, got %q", tclm.expectAnswer, l) 2061 } 2062 } 2063 return akp, c, cr, info.AuthRequired 2064 } 2065 2066 func setupAddTrusted(o *Options) { 2067 kp, _ := nkeys.FromSeed(oSeed) 2068 pub, _ := kp.PublicKey() 2069 o.TrustedKeys = []string{pub} 2070 } 2071 2072 func setupAddCookie(o *Options) { 2073 o.Websocket.JWTCookie = "jwt" 2074 } 2075 2076 func testWSCreateClientGetInfo(t testing.TB, compress, web bool, host string, port int, cookies ...string) (net.Conn, *bufio.Reader, []byte) { 2077 t.Helper() 2078 opts := testWSClientOptions{ 2079 compress: compress, 2080 web: web, 2081 host: host, 2082 port: port, 2083 } 2084 2085 if len(cookies) > 0 { 2086 opts.extraHeaders = map[string][]string{} 2087 opts.extraHeaders["Cookie"] = cookies 2088 } 2089 return testNewWSClient(t, opts) 2090 } 2091 2092 func testWSCreateClient(t testing.TB, compress, web bool, host string, port int) (net.Conn, *bufio.Reader) { 2093 wsc, br, _ := testWSCreateClientGetInfo(t, compress, web, host, port) 2094 // Send CONNECT and PING 2095 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n")) 2096 if _, err := wsc.Write(wsmsg); err != nil { 2097 t.Fatalf("Error sending message: %v", err) 2098 } 2099 // Wait for the PONG 2100 if msg := testWSReadFrame(t, br); 
!bytes.HasPrefix(msg, []byte("PONG\r\n")) { 2101 t.Fatalf("Expected PONG, got %s", msg) 2102 } 2103 return wsc, br 2104 } 2105 2106 func testWSReadFrame(t testing.TB, br *bufio.Reader) []byte { 2107 t.Helper() 2108 fh := [2]byte{} 2109 if _, err := io.ReadAtLeast(br, fh[:2], 2); err != nil { 2110 t.Fatalf("Error reading frame: %v", err) 2111 } 2112 fc := fh[0]&wsRsv1Bit != 0 2113 sb := fh[1] 2114 size := 0 2115 switch { 2116 case sb <= 125: 2117 size = int(sb) 2118 case sb == 126: 2119 tmp := [2]byte{} 2120 if _, err := io.ReadAtLeast(br, tmp[:2], 2); err != nil { 2121 t.Fatalf("Error reading frame: %v", err) 2122 } 2123 size = int(binary.BigEndian.Uint16(tmp[:2])) 2124 case sb == 127: 2125 tmp := [8]byte{} 2126 if _, err := io.ReadAtLeast(br, tmp[:8], 8); err != nil { 2127 t.Fatalf("Error reading frame: %v", err) 2128 } 2129 size = int(binary.BigEndian.Uint64(tmp[:8])) 2130 } 2131 buf := make([]byte, size) 2132 if _, err := io.ReadAtLeast(br, buf, size); err != nil { 2133 t.Fatalf("Error reading frame: %v", err) 2134 } 2135 if !fc { 2136 return buf 2137 } 2138 buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff) 2139 dbr := bytes.NewBuffer(buf) 2140 d := flate.NewReader(dbr) 2141 uncompressed, err := io.ReadAll(d) 2142 if err != nil { 2143 t.Fatalf("Error reading frame: %v", err) 2144 } 2145 return uncompressed 2146 } 2147 2148 func TestWSPubSub(t *testing.T) { 2149 for _, test := range []struct { 2150 name string 2151 compression bool 2152 }{ 2153 {"no compression", false}, 2154 {"compression", true}, 2155 } { 2156 t.Run(test.name, func(t *testing.T) { 2157 o := testWSOptions() 2158 if test.compression { 2159 o.Websocket.Compression = true 2160 } 2161 s := RunServer(o) 2162 defer s.Shutdown() 2163 2164 // Create a regular client to subscribe 2165 nc := natsConnect(t, s.ClientURL()) 2166 defer nc.Close() 2167 nsub := natsSubSync(t, nc, "foo") 2168 checkExpectedSubs(t, 1, s) 2169 2170 // Now create a WS client and send a message on "foo" 2171 
wsc, br := testWSCreateClient(t, test.compression, false, o.Websocket.Host, o.Websocket.Port) 2172 defer wsc.Close() 2173 2174 // Send a WS message for "PUB foo 2\r\nok\r\n" 2175 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n")) 2176 if _, err := wsc.Write(wsmsg); err != nil { 2177 t.Fatalf("Error sending message: %v", err) 2178 } 2179 2180 // Now check that message is received 2181 msg := natsNexMsg(t, nsub, time.Second) 2182 if string(msg.Data) != "from ws" { 2183 t.Fatalf("Expected message to be %q, got %q", "ok", string(msg.Data)) 2184 } 2185 2186 // Now do reverse, create a subscription on WS client on bar 2187 wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB bar 1\r\n")) 2188 if _, err := wsc.Write(wsmsg); err != nil { 2189 t.Fatalf("Error sending subscription: %v", err) 2190 } 2191 // Wait for it to be registered on server 2192 checkExpectedSubs(t, 2, s) 2193 // Now publish from NATS connection and verify received on WS client 2194 natsPub(t, nc, "bar", []byte("from nats")) 2195 natsFlush(t, nc) 2196 2197 // Check for the "from nats" message... 2198 // Set some deadline so we are not stuck forever on failure 2199 wsc.SetReadDeadline(time.Now().Add(10 * time.Second)) 2200 ok := 0 2201 for { 2202 line, _, err := br.ReadLine() 2203 if err != nil { 2204 t.Fatalf("Error reading: %v", err) 2205 } 2206 // Note that this works even in compression test because those 2207 // texts are likely not to be compressed, but compression code is 2208 // still executed. 
2209 if ok == 0 && bytes.Contains(line, []byte("MSG bar 1 9")) { 2210 ok = 1 2211 continue 2212 } else if ok == 1 && bytes.Contains(line, []byte("from nats")) { 2213 break 2214 } 2215 } 2216 }) 2217 } 2218 } 2219 2220 func TestWSTLSConnection(t *testing.T) { 2221 o := testWSOptions() 2222 s := RunServer(o) 2223 defer s.Shutdown() 2224 2225 addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) 2226 2227 for _, test := range []struct { 2228 name string 2229 useTLS bool 2230 status int 2231 }{ 2232 {"client uses TLS", true, http.StatusSwitchingProtocols}, 2233 {"client does not use TLS", false, http.StatusBadRequest}, 2234 } { 2235 t.Run(test.name, func(t *testing.T) { 2236 wsc, err := net.Dial("tcp", addr) 2237 if err != nil { 2238 t.Fatalf("Error creating ws connection: %v", err) 2239 } 2240 defer wsc.Close() 2241 if test.useTLS { 2242 wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true}) 2243 if err := wsc.(*tls.Conn).Handshake(); err != nil { 2244 t.Fatalf("Error during handshake: %v", err) 2245 } 2246 } 2247 req := testWSCreateValidReq() 2248 var scheme string 2249 if test.useTLS { 2250 scheme = "s" 2251 } 2252 req.URL, _ = url.Parse("ws" + scheme + "://" + addr) 2253 if err := req.Write(wsc); err != nil { 2254 t.Fatalf("Error sending request: %v", err) 2255 } 2256 br := bufio.NewReader(wsc) 2257 resp, err := http.ReadResponse(br, req) 2258 if err != nil { 2259 t.Fatalf("Error reading response: %v", err) 2260 } 2261 defer resp.Body.Close() 2262 if resp.StatusCode != test.status { 2263 t.Fatalf("Expected status %v, got %v", test.status, resp.StatusCode) 2264 } 2265 }) 2266 } 2267 } 2268 2269 func TestWSTLSVerifyClientCert(t *testing.T) { 2270 o := testWSOptions() 2271 tc := &TLSConfigOpts{ 2272 CertFile: "../test/configs/certs/server-cert.pem", 2273 KeyFile: "../test/configs/certs/server-key.pem", 2274 CaFile: "../test/configs/certs/ca.pem", 2275 Verify: true, 2276 } 2277 tlsc, err := GenTLSConfig(tc) 2278 if err != nil { 2279 t.Fatalf("Error 
creating tls config: %v", err) 2280 } 2281 o.Websocket.TLSConfig = tlsc 2282 s := RunServer(o) 2283 defer s.Shutdown() 2284 2285 addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) 2286 2287 for _, test := range []struct { 2288 name string 2289 provideCert bool 2290 }{ 2291 {"client provides cert", true}, 2292 {"client does not provide cert", false}, 2293 } { 2294 t.Run(test.name, func(t *testing.T) { 2295 wsc, err := net.Dial("tcp", addr) 2296 if err != nil { 2297 t.Fatalf("Error creating ws connection: %v", err) 2298 } 2299 defer wsc.Close() 2300 tlsc := &tls.Config{} 2301 if test.provideCert { 2302 tc := &TLSConfigOpts{ 2303 CertFile: "../test/configs/certs/client-cert.pem", 2304 KeyFile: "../test/configs/certs/client-key.pem", 2305 } 2306 var err error 2307 tlsc, err = GenTLSConfig(tc) 2308 if err != nil { 2309 t.Fatalf("Error generating tls config: %v", err) 2310 } 2311 } 2312 tlsc.InsecureSkipVerify = true 2313 wsc = tls.Client(wsc, tlsc) 2314 if err := wsc.(*tls.Conn).Handshake(); err != nil { 2315 t.Fatalf("Error during handshake: %v", err) 2316 } 2317 req := testWSCreateValidReq() 2318 req.URL, _ = url.Parse("wss://" + addr) 2319 if err := req.Write(wsc); err != nil { 2320 t.Fatalf("Error sending request: %v", err) 2321 } 2322 br := bufio.NewReader(wsc) 2323 resp, err := http.ReadResponse(br, req) 2324 if resp != nil { 2325 resp.Body.Close() 2326 } 2327 if !test.provideCert { 2328 if err == nil { 2329 t.Fatal("Expected error, did not get one") 2330 } else if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "certificate required") { 2331 t.Fatalf("Unexpected error: %v", err) 2332 } 2333 return 2334 } 2335 if err != nil { 2336 t.Fatalf("Unexpected error: %v", err) 2337 } 2338 if resp.StatusCode != http.StatusSwitchingProtocols { 2339 t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) 2340 } 2341 }) 2342 } 2343 } 2344 2345 func testCreateAllowedConnectionTypes(list []string) 
map[string]struct{} { 2346 if len(list) == 0 { 2347 return nil 2348 } 2349 m := make(map[string]struct{}, len(list)) 2350 for _, l := range list { 2351 m[l] = struct{}{} 2352 } 2353 return m 2354 } 2355 2356 func TestWSTLSVerifyAndMap(t *testing.T) { 2357 accName := "MyAccount" 2358 acc := NewAccount(accName) 2359 certUserName := "CN=example.com,OU=NATS.io" 2360 users := []*User{{Username: certUserName, Account: acc}} 2361 2362 for _, test := range []struct { 2363 name string 2364 filtering bool 2365 provideCert bool 2366 }{ 2367 {"no filtering, client provides cert", false, true}, 2368 {"no filtering, client does not provide cert", false, false}, 2369 {"filtering, client provides cert", true, true}, 2370 {"filtering, client does not provide cert", true, false}, 2371 {"no users override, client provides cert", false, true}, 2372 {"no users override, client does not provide cert", false, false}, 2373 {"users override, client provides cert", true, true}, 2374 {"users override, client does not provide cert", true, false}, 2375 } { 2376 t.Run(test.name, func(t *testing.T) { 2377 o := testWSOptions() 2378 o.Accounts = []*Account{acc} 2379 o.Users = users 2380 if test.filtering { 2381 o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}) 2382 } 2383 tc := &TLSConfigOpts{ 2384 CertFile: "../test/configs/certs/tlsauth/server.pem", 2385 KeyFile: "../test/configs/certs/tlsauth/server-key.pem", 2386 CaFile: "../test/configs/certs/tlsauth/ca.pem", 2387 Verify: true, 2388 } 2389 tlsc, err := GenTLSConfig(tc) 2390 if err != nil { 2391 t.Fatalf("Error creating tls config: %v", err) 2392 } 2393 o.Websocket.TLSConfig = tlsc 2394 o.Websocket.TLSMap = true 2395 s := RunServer(o) 2396 defer s.Shutdown() 2397 2398 addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) 2399 wsc, err := net.Dial("tcp", addr) 2400 if err != nil { 2401 t.Fatalf("Error creating ws connection: %v", err) 2402 } 2403 
defer wsc.Close() 2404 tlscc := &tls.Config{} 2405 if test.provideCert { 2406 tc := &TLSConfigOpts{ 2407 CertFile: "../test/configs/certs/tlsauth/client.pem", 2408 KeyFile: "../test/configs/certs/tlsauth/client-key.pem", 2409 } 2410 var err error 2411 tlscc, err = GenTLSConfig(tc) 2412 if err != nil { 2413 t.Fatalf("Error generating tls config: %v", err) 2414 } 2415 } 2416 tlscc.InsecureSkipVerify = true 2417 wsc = tls.Client(wsc, tlscc) 2418 if err := wsc.(*tls.Conn).Handshake(); err != nil { 2419 t.Fatalf("Error during handshake: %v", err) 2420 } 2421 req := testWSCreateValidReq() 2422 req.URL, _ = url.Parse("wss://" + addr) 2423 if err := req.Write(wsc); err != nil { 2424 t.Fatalf("Error sending request: %v", err) 2425 } 2426 br := bufio.NewReader(wsc) 2427 resp, err := http.ReadResponse(br, req) 2428 if resp != nil { 2429 resp.Body.Close() 2430 } 2431 if !test.provideCert { 2432 if err == nil { 2433 t.Fatal("Expected error, did not get one") 2434 } else if !strings.Contains(err.Error(), "bad certificate") && !strings.Contains(err.Error(), "certificate required") { 2435 t.Fatalf("Unexpected error: %v", err) 2436 } 2437 return 2438 } 2439 if err != nil { 2440 t.Fatalf("Unexpected error: %v", err) 2441 } 2442 if resp.StatusCode != http.StatusSwitchingProtocols { 2443 t.Fatalf("Expected status %v, got %v", http.StatusSwitchingProtocols, resp.StatusCode) 2444 } 2445 // Wait for the INFO 2446 l := testWSReadFrame(t, br) 2447 if !bytes.HasPrefix(l, []byte("INFO {")) { 2448 t.Fatalf("Expected INFO, got %s", l) 2449 } 2450 var info serverInfo 2451 if err := json.Unmarshal(l[5:], &info); err != nil { 2452 t.Fatalf("Unable to unmarshal info: %v", err) 2453 } 2454 // Send CONNECT and PING 2455 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n")) 2456 if _, err := wsc.Write(wsmsg); err != nil { 2457 t.Fatalf("Error sending message: %v", err) 2458 } 2459 // Wait for the PONG 2460 if msg := 
testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 2461 t.Fatalf("Expected PONG, got %s", msg) 2462 } 2463 2464 var uname string 2465 var accname string 2466 c := s.getClient(info.CID) 2467 if c != nil { 2468 c.mu.Lock() 2469 uname = c.opts.Username 2470 if c.acc != nil { 2471 accname = c.acc.GetName() 2472 } 2473 c.mu.Unlock() 2474 } 2475 if uname != certUserName { 2476 t.Fatalf("Expected username %q, got %q", certUserName, uname) 2477 } 2478 if accname != accName { 2479 t.Fatalf("Expected account %q, got %v", accName, accname) 2480 } 2481 }) 2482 } 2483 } 2484 2485 func TestWSHandshakeTimeout(t *testing.T) { 2486 o := testWSOptions() 2487 o.Websocket.HandshakeTimeout = time.Millisecond 2488 tc := &TLSConfigOpts{ 2489 CertFile: "./configs/certs/server.pem", 2490 KeyFile: "./configs/certs/key.pem", 2491 } 2492 o.Websocket.TLSConfig, _ = GenTLSConfig(tc) 2493 s := RunServer(o) 2494 defer s.Shutdown() 2495 2496 logger := &captureErrorLogger{errCh: make(chan string, 1)} 2497 s.SetLogger(logger, false, false) 2498 2499 addr := fmt.Sprintf("%s:%d", o.Websocket.Host, o.Websocket.Port) 2500 wsc, err := net.Dial("tcp", addr) 2501 if err != nil { 2502 t.Fatalf("Error creating ws connection: %v", err) 2503 } 2504 defer wsc.Close() 2505 2506 // Delay the handshake 2507 wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true}) 2508 time.Sleep(20 * time.Millisecond) 2509 // We expect error since the server should have cut us off 2510 if err := wsc.(*tls.Conn).Handshake(); err == nil { 2511 t.Fatal("Expected error during handshake") 2512 } 2513 2514 // Check that server logs error 2515 select { 2516 case e := <-logger.errCh: 2517 // Check that log starts with "websocket: " 2518 if !strings.HasPrefix(e, "websocket: ") { 2519 t.Fatalf("Wrong log line start: %s", e) 2520 } 2521 if !strings.Contains(e, "timeout") { 2522 t.Fatalf("Unexpected error: %v", e) 2523 } 2524 case <-time.After(time.Second): 2525 t.Fatalf("Should have timed-out") 2526 } 2527 } 2528 2529 
func TestWSServerReportUpgradeFailure(t *testing.T) { 2530 o := testWSOptions() 2531 s := RunServer(o) 2532 defer s.Shutdown() 2533 2534 logger := &captureErrorLogger{errCh: make(chan string, 1)} 2535 s.SetLogger(logger, false, false) 2536 2537 addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port) 2538 req := testWSCreateValidReq() 2539 req.URL, _ = url.Parse("wss://" + addr) 2540 2541 wsc, err := net.Dial("tcp", addr) 2542 if err != nil { 2543 t.Fatalf("Error creating ws connection: %v", err) 2544 } 2545 defer wsc.Close() 2546 wsc = tls.Client(wsc, &tls.Config{InsecureSkipVerify: true}) 2547 if err := wsc.(*tls.Conn).Handshake(); err != nil { 2548 t.Fatalf("Error during handshake: %v", err) 2549 } 2550 // Remove a required field from the request to have it fail 2551 req.Header.Del("Connection") 2552 // Send the request 2553 if err := req.Write(wsc); err != nil { 2554 t.Fatalf("Error sending request: %v", err) 2555 } 2556 br := bufio.NewReader(wsc) 2557 resp, err := http.ReadResponse(br, req) 2558 if err != nil { 2559 t.Fatalf("Error reading response: %v", err) 2560 } 2561 defer resp.Body.Close() 2562 if resp.StatusCode != http.StatusBadRequest { 2563 t.Fatalf("Expected status %v, got %v", http.StatusBadRequest, resp.StatusCode) 2564 } 2565 2566 // Check that server logs error 2567 select { 2568 case e := <-logger.errCh: 2569 if !strings.Contains(e, "invalid value for header 'Connection'") { 2570 t.Fatalf("Unexpected error: %v", e) 2571 } 2572 // The client IP's local should be printed as a remote from server perspective. 
2573 clientIP := wsc.LocalAddr().String() 2574 if !strings.HasPrefix(e, clientIP) { 2575 t.Fatalf("IP should have been logged, it was not: %v", e) 2576 } 2577 case <-time.After(time.Second): 2578 t.Fatalf("Should have timed-out") 2579 } 2580 } 2581 2582 func TestWSCloseMsgSendOnConnectionClose(t *testing.T) { 2583 o := testWSOptions() 2584 s := RunServer(o) 2585 defer s.Shutdown() 2586 2587 wsc, br := testWSCreateClient(t, false, false, o.Websocket.Host, o.Websocket.Port) 2588 defer wsc.Close() 2589 2590 checkClientsCount(t, s, 1) 2591 var c *client 2592 s.mu.Lock() 2593 for _, cli := range s.clients { 2594 c = cli 2595 break 2596 } 2597 s.mu.Unlock() 2598 2599 c.closeConnection(ProtocolViolation) 2600 msg := testWSReadFrame(t, br) 2601 if len(msg) < 2 { 2602 t.Fatalf("Should have 2 bytes to represent the status, got %v", msg) 2603 } 2604 if sc := int(binary.BigEndian.Uint16(msg[:2])); sc != wsCloseStatusProtocolError { 2605 t.Fatalf("Expected status to be %v, got %v", wsCloseStatusProtocolError, sc) 2606 } 2607 expectedPayload := ProtocolViolation.String() 2608 if p := string(msg[2:]); p != expectedPayload { 2609 t.Fatalf("Expected payload to be %q, got %q", expectedPayload, p) 2610 } 2611 } 2612 2613 func TestWSAdvertise(t *testing.T) { 2614 o := testWSOptions() 2615 o.Cluster.Port = 0 2616 o.HTTPPort = 0 2617 o.Websocket.Advertise = "xxx:host:yyy" 2618 s, err := NewServer(o) 2619 if err != nil { 2620 t.Fatalf("Unexpected error: %v", err) 2621 } 2622 defer s.Shutdown() 2623 l := &captureFatalLogger{fatalCh: make(chan string, 1)} 2624 s.SetLogger(l, false, false) 2625 s.Start() 2626 select { 2627 case e := <-l.fatalCh: 2628 if !strings.Contains(e, "Unable to get websocket connect URLs") { 2629 t.Fatalf("Unexpected error: %q", e) 2630 } 2631 case <-time.After(time.Second): 2632 t.Fatal("Should have failed to start") 2633 } 2634 s.Shutdown() 2635 2636 o1 := testWSOptions() 2637 o1.Websocket.Advertise = "host1:1234" 2638 s1 := RunServer(o1) 2639 defer s1.Shutdown() 
2640 2641 wsc, br := testWSCreateClient(t, false, false, o1.Websocket.Host, o1.Websocket.Port) 2642 defer wsc.Close() 2643 2644 o2 := testWSOptions() 2645 o2.Websocket.Advertise = "host2:5678" 2646 o2.Routes = RoutesFromStr(fmt.Sprintf("nats://%s:%d", o1.Cluster.Host, o1.Cluster.Port)) 2647 s2 := RunServer(o2) 2648 defer s2.Shutdown() 2649 2650 checkInfo := func(expected []string) { 2651 t.Helper() 2652 infob := testWSReadFrame(t, br) 2653 info := &Info{} 2654 json.Unmarshal(infob[5:], info) 2655 if n := len(info.ClientConnectURLs); n != len(expected) { 2656 t.Fatalf("Unexpected info: %+v", info) 2657 } 2658 good := 0 2659 for _, u := range info.ClientConnectURLs { 2660 for _, eu := range expected { 2661 if u == eu { 2662 good++ 2663 } 2664 } 2665 } 2666 if good != len(expected) { 2667 t.Fatalf("Unexpected connect urls: %q", info.ClientConnectURLs) 2668 } 2669 } 2670 checkInfo([]string{"host1:1234", "host2:5678"}) 2671 2672 // Now shutdown s2 and expect another INFO 2673 s2.Shutdown() 2674 checkInfo([]string{"host1:1234"}) 2675 2676 // Restart with another advertise and check that it gets updated 2677 o2.Websocket.Advertise = "host3:9012" 2678 s2 = RunServer(o2) 2679 defer s2.Shutdown() 2680 checkInfo([]string{"host1:1234", "host3:9012"}) 2681 } 2682 2683 func TestWSFrameOutbound(t *testing.T) { 2684 for _, test := range []struct { 2685 name string 2686 maskingWrite bool 2687 }{ 2688 {"no write masking", false}, 2689 {"write masking", true}, 2690 } { 2691 t.Run(test.name, func(t *testing.T) { 2692 c, _, _ := testWSSetupForRead() 2693 c.ws.maskwrite = test.maskingWrite 2694 2695 getKey := func(buf []byte) []byte { 2696 return buf[len(buf)-4:] 2697 } 2698 2699 var bufs net.Buffers 2700 bufs = append(bufs, []byte("this ")) 2701 bufs = append(bufs, []byte("is ")) 2702 bufs = append(bufs, []byte("a ")) 2703 bufs = append(bufs, []byte("set ")) 2704 bufs = append(bufs, []byte("of ")) 2705 bufs = append(bufs, []byte("buffers")) 2706 en := 2 2707 for _, b := range bufs { 
2708 en += len(b) 2709 } 2710 if test.maskingWrite { 2711 en += 4 2712 } 2713 c.mu.Lock() 2714 c.out.nb = bufs 2715 res, n := c.collapsePtoNB() 2716 c.mu.Unlock() 2717 if n != int64(en) { 2718 t.Fatalf("Expected size to be %v, got %v", en, n) 2719 } 2720 if eb := 1 + len(bufs); eb != len(res) { 2721 t.Fatalf("Expected %v buffers, got %v", eb, len(res)) 2722 } 2723 var ob []byte 2724 for i := 1; i < len(res); i++ { 2725 ob = append(ob, res[i]...) 2726 } 2727 if test.maskingWrite { 2728 wsMaskBuf(getKey(res[0]), ob) 2729 } 2730 if !bytes.Equal(ob, []byte("this is a set of buffers")) { 2731 t.Fatalf("Unexpected outbound: %q", ob) 2732 } 2733 2734 bufs = nil 2735 c.out.pb = 0 2736 c.ws.fs = 0 2737 c.ws.frames = nil 2738 c.ws.browser = true 2739 bufs = append(bufs, []byte("some smaller ")) 2740 bufs = append(bufs, []byte("buffers")) 2741 bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+10)) 2742 bufs = append(bufs, []byte("then some more")) 2743 en = 2 + len(bufs[0]) + len(bufs[1]) 2744 en += 4 + len(bufs[2]) - 10 2745 en += 2 + len(bufs[3]) + 10 2746 c.mu.Lock() 2747 c.out.nb = bufs 2748 res, n = c.collapsePtoNB() 2749 c.mu.Unlock() 2750 if test.maskingWrite { 2751 en += 3 * 4 2752 } 2753 if n != int64(en) { 2754 t.Fatalf("Expected size to be %v, got %v", en, n) 2755 } 2756 if len(res) != 8 { 2757 t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) 2758 } 2759 if len(res[4]) != wsFrameSizeForBrowsers { 2760 t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) 2761 } 2762 if len(res[6]) != 10 { 2763 t.Fatalf("Frame 6 should have the partial of 10 bytes, got %v", len(res[6])) 2764 } 2765 if test.maskingWrite { 2766 b := &bytes.Buffer{} 2767 key := getKey(res[0]) 2768 b.Write(res[1]) 2769 b.Write(res[2]) 2770 ud := b.Bytes() 2771 wsMaskBuf(key, ud) 2772 if string(ud) != "some smaller buffers" { 2773 t.Fatalf("Unexpected result: %q", ud) 2774 } 2775 2776 b.Reset() 2777 key = getKey(res[3]) 2778 b.Write(res[4]) 
2779 ud = b.Bytes() 2780 wsMaskBuf(key, ud) 2781 for i := 0; i < len(ud); i++ { 2782 if ud[i] != 0 { 2783 t.Fatalf("Unexpected result: %v", ud) 2784 } 2785 } 2786 2787 b.Reset() 2788 key = getKey(res[5]) 2789 b.Write(res[6]) 2790 b.Write(res[7]) 2791 ud = b.Bytes() 2792 wsMaskBuf(key, ud) 2793 for i := 0; i < len(ud[:10]); i++ { 2794 if ud[i] != 0 { 2795 t.Fatalf("Unexpected result: %v", ud[:10]) 2796 } 2797 } 2798 if string(ud[10:]) != "then some more" { 2799 t.Fatalf("Unexpected result: %q", ud[10:]) 2800 } 2801 } 2802 2803 bufs = nil 2804 c.out.pb = 0 2805 c.ws.fs = 0 2806 c.ws.frames = nil 2807 c.ws.browser = true 2808 bufs = append(bufs, []byte("some smaller ")) 2809 bufs = append(bufs, []byte("buffers")) 2810 // Have one of the exact max size 2811 bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers)) 2812 bufs = append(bufs, []byte("then some more")) 2813 en = 2 + len(bufs[0]) + len(bufs[1]) 2814 en += 4 + len(bufs[2]) 2815 en += 2 + len(bufs[3]) 2816 c.mu.Lock() 2817 c.out.nb = bufs 2818 res, n = c.collapsePtoNB() 2819 c.mu.Unlock() 2820 if test.maskingWrite { 2821 en += 3 * 4 2822 } 2823 if n != int64(en) { 2824 t.Fatalf("Expected size to be %v, got %v", en, n) 2825 } 2826 if len(res) != 7 { 2827 t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) 2828 } 2829 if len(res[4]) != wsFrameSizeForBrowsers { 2830 t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) 2831 } 2832 if test.maskingWrite { 2833 key := getKey(res[5]) 2834 wsMaskBuf(key, res[6]) 2835 } 2836 if string(res[6]) != "then some more" { 2837 t.Fatalf("Frame 6 incorrect: %q", res[6]) 2838 } 2839 2840 bufs = nil 2841 c.out.pb = 0 2842 c.ws.fs = 0 2843 c.ws.frames = nil 2844 c.ws.browser = true 2845 bufs = append(bufs, []byte("some smaller ")) 2846 bufs = append(bufs, []byte("buffers")) 2847 // Have one of the exact max size, and last in the list 2848 bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers)) 2849 en = 2 + len(bufs[0]) + 
len(bufs[1]) 2850 en += 4 + len(bufs[2]) 2851 c.mu.Lock() 2852 c.out.nb = bufs 2853 res, n = c.collapsePtoNB() 2854 c.mu.Unlock() 2855 if test.maskingWrite { 2856 en += 2 * 4 2857 } 2858 if n != int64(en) { 2859 t.Fatalf("Expected size to be %v, got %v", en, n) 2860 } 2861 if len(res) != 5 { 2862 t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) 2863 } 2864 if len(res[4]) != wsFrameSizeForBrowsers { 2865 t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) 2866 } 2867 2868 bufs = nil 2869 c.out.pb = 0 2870 c.ws.fs = 0 2871 c.ws.frames = nil 2872 c.ws.browser = true 2873 bufs = append(bufs, []byte("some smaller buffer")) 2874 bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers-5)) 2875 bufs = append(bufs, []byte("then some more")) 2876 en = 2 + len(bufs[0]) 2877 en += 4 + len(bufs[1]) 2878 en += 2 + len(bufs[2]) 2879 c.mu.Lock() 2880 c.out.nb = bufs 2881 res, n = c.collapsePtoNB() 2882 c.mu.Unlock() 2883 if test.maskingWrite { 2884 en += 3 * 4 2885 } 2886 if n != int64(en) { 2887 t.Fatalf("Expected size to be %v, got %v", en, n) 2888 } 2889 if len(res) != 6 { 2890 t.Fatalf("Unexpected number of outbound buffers: %v", len(res)) 2891 } 2892 if len(res[3]) != wsFrameSizeForBrowsers-5 { 2893 t.Fatalf("Big frame should have been limited to %v, got %v", wsFrameSizeForBrowsers, len(res[4])) 2894 } 2895 if test.maskingWrite { 2896 key := getKey(res[4]) 2897 wsMaskBuf(key, res[5]) 2898 } 2899 if string(res[5]) != "then some more" { 2900 t.Fatalf("Frame 6 incorrect %q", res[5]) 2901 } 2902 2903 bufs = nil 2904 c.out.pb = 0 2905 c.ws.fs = 0 2906 c.ws.frames = nil 2907 c.ws.browser = true 2908 bufs = append(bufs, make([]byte, wsFrameSizeForBrowsers+100)) 2909 c.mu.Lock() 2910 c.out.nb = bufs 2911 res, _ = c.collapsePtoNB() 2912 c.mu.Unlock() 2913 if len(res) != 4 { 2914 t.Fatalf("Unexpected number of frames: %v", len(res)) 2915 } 2916 }) 2917 } 2918 } 2919 2920 func TestWSWebrowserClient(t *testing.T) { 2921 o := 
testWSOptions() 2922 s := RunServer(o) 2923 defer s.Shutdown() 2924 2925 wsc, br := testWSCreateClient(t, false, true, o.Websocket.Host, o.Websocket.Port) 2926 defer wsc.Close() 2927 2928 checkClientsCount(t, s, 1) 2929 var c *client 2930 s.mu.Lock() 2931 for _, cli := range s.clients { 2932 c = cli 2933 break 2934 } 2935 s.mu.Unlock() 2936 2937 proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("SUB foo 1\r\nPING\r\n")) 2938 wsc.Write(proto) 2939 if res := testWSReadFrame(t, br); !bytes.Equal(res, []byte(pongProto)) { 2940 t.Fatalf("Expected PONG back") 2941 } 2942 2943 c.mu.Lock() 2944 ok := c.isWebsocket() && c.ws.browser == true 2945 c.mu.Unlock() 2946 if !ok { 2947 t.Fatalf("Client is not marked as webrowser client") 2948 } 2949 2950 nc := natsConnect(t, s.ClientURL()) 2951 defer nc.Close() 2952 2953 // Send a big message and check that it is received in smaller frames 2954 psize := 204813 2955 nc.Publish("foo", make([]byte, psize)) 2956 nc.Flush() 2957 2958 rsize := psize + len(fmt.Sprintf("MSG foo %d\r\n\r\n", psize)) 2959 nframes := 0 2960 for total := 0; total < rsize; nframes++ { 2961 res := testWSReadFrame(t, br) 2962 total += len(res) 2963 } 2964 if expected := psize / wsFrameSizeForBrowsers; expected > nframes { 2965 t.Fatalf("Expected %v frames, got %v", expected, nframes) 2966 } 2967 } 2968 2969 type testWSWrappedConn struct { 2970 net.Conn 2971 mu sync.RWMutex 2972 buf *bytes.Buffer 2973 partial bool 2974 } 2975 2976 func (wc *testWSWrappedConn) Write(p []byte) (int, error) { 2977 wc.mu.Lock() 2978 defer wc.mu.Unlock() 2979 var err error 2980 n := len(p) 2981 if wc.partial && n > 10 { 2982 n = 10 2983 err = io.ErrShortWrite 2984 } 2985 p = p[:n] 2986 wc.buf.Write(p) 2987 wc.Conn.Write(p) 2988 return n, err 2989 } 2990 2991 func TestWSCompressionBasic(t *testing.T) { 2992 payload := "This is the content of a message that will be compresseddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddddd." 
2993 msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload) 2994 cbuf := &bytes.Buffer{} 2995 compressor, err := flate.NewWriter(cbuf, flate.BestSpeed) 2996 require_NoError(t, err) 2997 compressor.Write([]byte(msgProto)) 2998 compressor.Flush() 2999 compressed := cbuf.Bytes() 3000 // The last 4 bytes are dropped 3001 compressed = compressed[:len(compressed)-4] 3002 3003 o := testWSOptions() 3004 o.Websocket.Compression = true 3005 s := RunServer(o) 3006 defer s.Shutdown() 3007 3008 c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port) 3009 defer c.Close() 3010 3011 proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n")) 3012 c.Write(proto) 3013 l := testWSReadFrame(t, br) 3014 if !bytes.Equal(l, []byte(pongProto)) { 3015 t.Fatalf("Expected PONG, got %q", l) 3016 } 3017 3018 var wc *testWSWrappedConn 3019 s.mu.RLock() 3020 for _, c := range s.clients { 3021 c.mu.Lock() 3022 wc = &testWSWrappedConn{Conn: c.nc, buf: &bytes.Buffer{}} 3023 c.nc = wc 3024 c.mu.Unlock() 3025 } 3026 s.mu.RUnlock() 3027 3028 nc := natsConnect(t, s.ClientURL()) 3029 defer nc.Close() 3030 natsPub(t, nc, "foo", []byte(payload)) 3031 3032 res := &bytes.Buffer{} 3033 for total := 0; total < len(msgProto); { 3034 l := testWSReadFrame(t, br) 3035 n, _ := res.Write(l) 3036 total += n 3037 } 3038 if !bytes.Equal([]byte(msgProto), res.Bytes()) { 3039 t.Fatalf("Unexpected result: %q", res) 3040 } 3041 3042 // Now check the wrapped connection buffer to check that data was actually compressed. 
3043 wc.mu.RLock() 3044 res = wc.buf 3045 wc.mu.RUnlock() 3046 if bytes.Contains(res.Bytes(), []byte(payload)) { 3047 t.Fatalf("Looks like frame was not compressed: %q", res.Bytes()) 3048 } 3049 header := res.Bytes()[:2] 3050 body := res.Bytes()[2:] 3051 expectedB0 := byte(wsBinaryMessage) | wsFinalBit | wsRsv1Bit 3052 expectedPS := len(compressed) 3053 expectedB1 := byte(expectedPS) 3054 3055 if b := header[0]; b != expectedB0 { 3056 t.Fatalf("Expected first byte to be %v, got %v", expectedB0, b) 3057 } 3058 if b := header[1]; b != expectedB1 { 3059 t.Fatalf("Expected second byte to be %v, got %v", expectedB1, b) 3060 } 3061 if len(body) != expectedPS { 3062 t.Fatalf("Expected payload length to be %v, got %v", expectedPS, len(body)) 3063 } 3064 if !bytes.Equal(body, compressed) { 3065 t.Fatalf("Unexpected compress body: %q", body) 3066 } 3067 3068 wc.mu.Lock() 3069 wc.buf.Reset() 3070 wc.mu.Unlock() 3071 3072 payload = "small" 3073 natsPub(t, nc, "foo", []byte(payload)) 3074 msgProto = fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload) 3075 res = &bytes.Buffer{} 3076 for total := 0; total < len(msgProto); { 3077 l := testWSReadFrame(t, br) 3078 n, _ := res.Write(l) 3079 total += n 3080 } 3081 if !bytes.Equal([]byte(msgProto), res.Bytes()) { 3082 t.Fatalf("Unexpected result: %q", res) 3083 } 3084 wc.mu.RLock() 3085 res = wc.buf 3086 wc.mu.RUnlock() 3087 if !bytes.HasSuffix(res.Bytes(), []byte(msgProto)) { 3088 t.Fatalf("Looks like frame was compressed: %q", res.Bytes()) 3089 } 3090 } 3091 3092 func TestWSCompressionWithPartialWrite(t *testing.T) { 3093 payload := "This is the content of a message that will be compresseddddddddddddddddddddd." 
3094 msgProto := fmt.Sprintf("MSG foo 1 %d\r\n%s\r\n", len(payload), payload) 3095 3096 o := testWSOptions() 3097 o.Websocket.Compression = true 3098 s := RunServer(o) 3099 defer s.Shutdown() 3100 3101 c, br := testWSCreateClient(t, true, false, o.Websocket.Host, o.Websocket.Port) 3102 defer c.Close() 3103 3104 proto := testWSCreateClientMsg(wsBinaryMessage, 1, true, true, []byte("SUB foo 1\r\nPING\r\n")) 3105 c.Write(proto) 3106 l := testWSReadFrame(t, br) 3107 if !bytes.Equal(l, []byte(pongProto)) { 3108 t.Fatalf("Expected PONG, got %q", l) 3109 } 3110 3111 pingPayload := []byte("my ping") 3112 pingFromWSClient := testWSCreateClientMsg(wsPingMessage, 1, true, false, pingPayload) 3113 3114 var wc *testWSWrappedConn 3115 var ws *client 3116 s.mu.Lock() 3117 for _, c := range s.clients { 3118 ws = c 3119 c.mu.Lock() 3120 wc = &testWSWrappedConn{ 3121 Conn: c.nc, 3122 buf: &bytes.Buffer{}, 3123 } 3124 c.nc = wc 3125 c.mu.Unlock() 3126 break 3127 } 3128 s.mu.Unlock() 3129 3130 wc.mu.Lock() 3131 wc.partial = true 3132 wc.mu.Unlock() 3133 3134 nc := natsConnect(t, s.ClientURL()) 3135 defer nc.Close() 3136 3137 expected := &bytes.Buffer{} 3138 for i := 0; i < 10; i++ { 3139 if i > 0 { 3140 time.Sleep(10 * time.Millisecond) 3141 } 3142 expected.Write([]byte(msgProto)) 3143 natsPub(t, nc, "foo", []byte(payload)) 3144 if i == 1 { 3145 c.Write(pingFromWSClient) 3146 } 3147 } 3148 3149 var gotPingResponse bool 3150 res := &bytes.Buffer{} 3151 for total := 0; total < 10*len(msgProto); { 3152 l := testWSReadFrame(t, br) 3153 if bytes.Equal(l, pingPayload) { 3154 gotPingResponse = true 3155 } else { 3156 n, _ := res.Write(l) 3157 total += n 3158 } 3159 } 3160 if !bytes.Equal(expected.Bytes(), res.Bytes()) { 3161 t.Fatalf("Unexpected result: %q", res) 3162 } 3163 if !gotPingResponse { 3164 t.Fatal("Did not get the ping response") 3165 } 3166 3167 checkFor(t, time.Second, 15*time.Millisecond, func() error { 3168 ws.mu.Lock() 3169 pb := ws.out.pb 3170 wf := ws.ws.frames 3171 fs := 
// TestWSCompressionFrameSizeLimit builds a bare client with websocket
// compression and the browser flag set, queues a random payload twice the size
// of wsFrameSizeForBrowsers, and inspects the buffers returned by
// collapsePtoNB: no buffer may exceed wsFrameSizeForBrowsers, only the first
// frame header may carry the RSV1 (compressed) bit, and the concatenated
// payload buffers must inflate back to the original payload.
//
// NOTE(review): no row of the table sets noLimit to true, so the noLimit
// branches below are currently never exercised.
func TestWSCompressionFrameSizeLimit(t *testing.T) {
	for _, test := range []struct {
		name      string
		maskWrite bool
		noLimit   bool
	}{
		{"no write masking", false, false},
		{"write masking", true, false},
	} {
		t.Run(test.name, func(t *testing.T) {
			opts := testWSOptions()
			opts.MaxPending = MAX_PENDING_SIZE
			s := &Server{opts: opts}
			c := &client{srv: s, ws: &websocket{compress: true, browser: true, nocompfrag: test.noLimit, maskwrite: test.maskWrite}}
			c.initClient()

			// Random bytes, so the payload is unlikely to compress well.
			uncompressedPayload := make([]byte, 2*wsFrameSizeForBrowsers)
			for i := 0; i < len(uncompressedPayload); i++ {
				uncompressedPayload[i] = byte(rand.Intn(256))
			}

			c.mu.Lock()
			c.out.nb = append(net.Buffers(nil), uncompressedPayload)
			nb, _ := c.collapsePtoNB()
			c.mu.Unlock()

			if test.noLimit && len(nb) != 2 {
				t.Fatalf("There should be only 2 buffers, the header and payload, got %v", len(nb))
			}

			// The loop treats even indexes as frame headers and odd indexes as
			// the payload that follows the preceding header.
			bb := &bytes.Buffer{}
			var key []byte
			for i, b := range nb {
				if !test.noLimit {
					// Frame header buffers are always very small. A payload buffer
					// should not be larger than wsFrameSizeForBrowsers, which is the
					// limit being verified here.
					if len(b) > wsFrameSizeForBrowsers {
						t.Fatalf("Frame size too big: %v (%q)", len(b), b)
					}
				}
				if test.maskWrite {
					if i%2 == 0 {
						// The last 4 bytes of the header buffer are taken as the mask key.
						key = b[len(b)-4:]
					} else {
						// Presumably reverses the masking in place so the payload can be
						// inflated below; verify against wsMaskBuf.
						wsMaskBuf(key, b)
					}
				}
				// Check frame headers for the proper formatting.
				if i%2 == 0 {
					// Only the first frame should have the compress bit set.
					if b[0]&wsRsv1Bit != 0 {
						if i > 0 {
							t.Fatalf("Compressed bit should not be in continuation frame")
						}
					} else if i == 0 {
						t.Fatalf("Compressed bit missing")
					}
				} else {
					if test.noLimit {
						// Since the payload is likely not well compressed, we are expecting
						// the length to be > wsFrameSizeForBrowsers
						if len(b) <= wsFrameSizeForBrowsers {
							t.Fatalf("Expected frame to be bigger, got %v", len(b))
						}
					}
					// Collect the payload
					bb.Write(b)
				}
			}
			buf := bb.Bytes()
			// Append a deflate trailer before inflating. NOTE(review): this looks like
			// the 0x00 0x00 0xff 0xff sync marker that permessage-deflate strips,
			// followed by a final empty stored block; confirm against the server's
			// decompression path.
			buf = append(buf, 0x00, 0x00, 0xff, 0xff, 0x01, 0x00, 0x00, 0xff, 0xff)
			dbr := bytes.NewBuffer(buf)
			d := flate.NewReader(dbr)
			uncompressed, err := io.ReadAll(d)
			if err != nil {
				t.Fatalf("Error reading frame: %v", err)
			}
			if !bytes.Equal(uncompressed, uncompressedPayload) {
				t.Fatalf("Unexpected uncomressed data: %q", uncompressed)
			}
		})
	}
}
*Options { 3309 o := testWSOptions() 3310 o.Websocket.Username = "websocket" 3311 o.Websocket.Password = "client" 3312 return o 3313 }, 3314 "websocket", "client", "", 3315 nil, 3316 }, 3317 { 3318 "top level auth, ws override, wrong u/p", 3319 func() *Options { 3320 o := testWSOptions() 3321 o.Username = "normal" 3322 o.Password = "client" 3323 o.Websocket.Username = "websocket" 3324 o.Websocket.Password = "client" 3325 return o 3326 }, 3327 "normal", "client", "-ERR 'Authorization Violation'", 3328 nil, 3329 }, 3330 { 3331 "top level auth, ws override, correct u/p", 3332 func() *Options { 3333 o := testWSOptions() 3334 o.Username = "normal" 3335 o.Password = "client" 3336 o.Websocket.Username = "websocket" 3337 o.Websocket.Password = "client" 3338 return o 3339 }, 3340 "websocket", "client", "", 3341 nil, 3342 }, 3343 { 3344 "username/password from cookies", 3345 func() *Options { 3346 o := testWSOptions() 3347 o.Websocket.UsernameCookie = "un" 3348 o.Websocket.PasswordCookie = "pw" 3349 o.Username = "me" 3350 o.Password = "s3cr3t!" 3351 return o 3352 }, 3353 "", "", "", 3354 []string{"un=me", "pw=s3cr3t!"}, 3355 }, 3356 { 3357 "bad username/ good password from cookies", 3358 func() *Options { 3359 o := testWSOptions() 3360 o.Websocket.UsernameCookie = "un" 3361 o.Websocket.PasswordCookie = "pw" 3362 o.Username = "me" 3363 o.Password = "s3cr3t!" 3364 return o 3365 }, 3366 "", "", "-ERR 'Authorization Violation", 3367 []string{"un=m", "pw=s3cr3t!"}, 3368 }, 3369 { 3370 "good username/ bad password from cookies", 3371 func() *Options { 3372 o := testWSOptions() 3373 o.Websocket.UsernameCookie = "un" 3374 o.Websocket.PasswordCookie = "pw" 3375 o.Username = "me" 3376 o.Password = "s3cr3t!" 3377 return o 3378 }, 3379 "", "", "-ERR 'Authorization Violation", 3380 []string{"un=me", "pw=hi!"}, 3381 }, 3382 { 3383 "token from cookie", 3384 func() *Options { 3385 o := testWSOptions() 3386 o.Websocket.TokenCookie = "tok" 3387 o.Authorization = "l3tm31n!" 
3388 return o 3389 }, 3390 "", "", "", 3391 []string{"tok=l3tm31n!"}, 3392 }, 3393 { 3394 "bad token from cookie", 3395 func() *Options { 3396 o := testWSOptions() 3397 o.Websocket.TokenCookie = "tok" 3398 o.Authorization = "l3tm31n!" 3399 return o 3400 }, 3401 "", "", "-ERR 'Authorization Violation", 3402 []string{"tok=hello!"}, 3403 }, 3404 } { 3405 t.Run(test.name, func(t *testing.T) { 3406 o := test.opts() 3407 s := RunServer(o) 3408 defer s.Shutdown() 3409 3410 wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port, test.cookies...) 3411 defer wsc.Close() 3412 3413 connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", 3414 test.user, test.pass) 3415 3416 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3417 if _, err := wsc.Write(wsmsg); err != nil { 3418 t.Fatalf("Error sending message: %v", err) 3419 } 3420 msg := testWSReadFrame(t, br) 3421 if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3422 t.Fatalf("Expected to receive PONG, got %q", msg) 3423 } else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { 3424 t.Fatalf("Expected to receive %q, got %q", test.err, msg) 3425 } 3426 }) 3427 } 3428 } 3429 3430 func TestWSAuthTimeout(t *testing.T) { 3431 for _, test := range []struct { 3432 name string 3433 at float64 3434 wat float64 3435 err string 3436 }{ 3437 {"use top-level auth timeout", 10.0, 0.0, ""}, 3438 {"use websocket auth timeout", 10.0, 0.05, "-ERR 'Authentication Timeout'"}, 3439 } { 3440 t.Run(test.name, func(t *testing.T) { 3441 o := testWSOptions() 3442 o.AuthTimeout = test.at 3443 o.Websocket.Username = "websocket" 3444 o.Websocket.Password = "client" 3445 o.Websocket.AuthTimeout = test.wat 3446 s := RunServer(o) 3447 defer s.Shutdown() 3448 3449 wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) 3450 defer wsc.Close() 3451 3452 var info 
serverInfo 3453 json.Unmarshal([]byte(l[5:]), &info) 3454 // Make sure that we are told that auth is required. 3455 if !info.AuthRequired { 3456 t.Fatalf("Expected auth required, was not: %q", l) 3457 } 3458 start := time.Now() 3459 // Wait before sending connect 3460 time.Sleep(100 * time.Millisecond) 3461 connectProto := "CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"websocket\",\"pass\":\"client\"}\r\nPING\r\n" 3462 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3463 if _, err := wsc.Write(wsmsg); err != nil { 3464 t.Fatalf("Error sending message: %v", err) 3465 } 3466 msg := testWSReadFrame(t, br) 3467 if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { 3468 t.Fatalf("Expected to receive %q error, got %q", test.err, msg) 3469 } else if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3470 t.Fatalf("Unexpected error: %q", msg) 3471 } 3472 if dur := time.Since(start); dur > time.Second { 3473 t.Fatalf("Too long to get timeout error: %v", dur) 3474 } 3475 }) 3476 } 3477 } 3478 3479 func TestWSTokenAuth(t *testing.T) { 3480 for _, test := range []struct { 3481 name string 3482 opts func() *Options 3483 token string 3484 err string 3485 }{ 3486 { 3487 "top level auth, no override, wrong token", 3488 func() *Options { 3489 o := testWSOptions() 3490 o.Authorization = "goodtoken" 3491 return o 3492 }, 3493 "badtoken", "-ERR 'Authorization Violation'", 3494 }, 3495 { 3496 "top level auth, no override, correct token", 3497 func() *Options { 3498 o := testWSOptions() 3499 o.Authorization = "goodtoken" 3500 return o 3501 }, 3502 "goodtoken", "", 3503 }, 3504 { 3505 "no top level auth, ws auth, wrong token", 3506 func() *Options { 3507 o := testWSOptions() 3508 o.Websocket.Token = "goodtoken" 3509 return o 3510 }, 3511 "badtoken", "-ERR 'Authorization Violation'", 3512 }, 3513 { 3514 "no top level auth, ws auth, correct token", 3515 func() *Options { 3516 o := testWSOptions() 3517 o.Websocket.Token = 
"goodtoken" 3518 return o 3519 }, 3520 "goodtoken", "", 3521 }, 3522 { 3523 "top level auth, ws override, wrong token", 3524 func() *Options { 3525 o := testWSOptions() 3526 o.Authorization = "clienttoken" 3527 o.Websocket.Token = "websockettoken" 3528 return o 3529 }, 3530 "clienttoken", "-ERR 'Authorization Violation'", 3531 }, 3532 { 3533 "top level auth, ws override, correct token", 3534 func() *Options { 3535 o := testWSOptions() 3536 o.Authorization = "clienttoken" 3537 o.Websocket.Token = "websockettoken" 3538 return o 3539 }, 3540 "websockettoken", "", 3541 }, 3542 } { 3543 t.Run(test.name, func(t *testing.T) { 3544 o := test.opts() 3545 s := RunServer(o) 3546 defer s.Shutdown() 3547 3548 wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) 3549 defer wsc.Close() 3550 3551 connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"auth_token\":\"%s\"}\r\nPING\r\n", 3552 test.token) 3553 3554 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3555 if _, err := wsc.Write(wsmsg); err != nil { 3556 t.Fatalf("Error sending message: %v", err) 3557 } 3558 msg := testWSReadFrame(t, br) 3559 if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3560 t.Fatalf("Expected to receive PONG, got %q", msg) 3561 } else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { 3562 t.Fatalf("Expected to receive %q, got %q", test.err, msg) 3563 } 3564 }) 3565 } 3566 } 3567 3568 func TestWSBindToProperAccount(t *testing.T) { 3569 conf := createConfFile(t, []byte(fmt.Sprintf(` 3570 listen: "127.0.0.1:-1" 3571 accounts { 3572 a { 3573 users [ 3574 {user: a, password: pwd, allowed_connection_types: ["%s", "%s"]} 3575 ] 3576 } 3577 b { 3578 users [ 3579 {user: b, password: pwd} 3580 ] 3581 } 3582 } 3583 websocket { 3584 listen: "127.0.0.1:-1" 3585 no_tls: true 3586 } 3587 `, jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)))) // on purpose use lower case 
to ensure that it is converted. 3588 s, o := RunServerWithConfig(conf) 3589 defer s.Shutdown() 3590 3591 nc := natsConnect(t, fmt.Sprintf("nats://a:pwd@127.0.0.1:%d", o.Port)) 3592 defer nc.Close() 3593 3594 sub := natsSubSync(t, nc, "foo") 3595 3596 wsc, br, _ := testNewWSClient(t, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port, noTLS: true}) 3597 // Send CONNECT and PING 3598 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, 3599 []byte(fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", "a", "pwd"))) 3600 if _, err := wsc.Write(wsmsg); err != nil { 3601 t.Fatalf("Error sending message: %v", err) 3602 } 3603 // Wait for the PONG 3604 if msg := testWSReadFrame(t, br); !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3605 t.Fatalf("Expected PONG, got %s", msg) 3606 } 3607 3608 wsmsg = testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte("PUB foo 7\r\nfrom ws\r\n")) 3609 if _, err := wsc.Write(wsmsg); err != nil { 3610 t.Fatalf("Error sending message: %v", err) 3611 } 3612 3613 natsNexMsg(t, sub, time.Second) 3614 } 3615 3616 func TestWSUsersAuth(t *testing.T) { 3617 users := []*User{{Username: "user", Password: "pwd"}} 3618 for _, test := range []struct { 3619 name string 3620 opts func() *Options 3621 user string 3622 pass string 3623 err string 3624 }{ 3625 { 3626 "no filtering, wrong user", 3627 func() *Options { 3628 o := testWSOptions() 3629 o.Users = users 3630 return o 3631 }, 3632 "wronguser", "pwd", "-ERR 'Authorization Violation'", 3633 }, 3634 { 3635 "no filtering, correct user", 3636 func() *Options { 3637 o := testWSOptions() 3638 o.Users = users 3639 return o 3640 }, 3641 "user", "pwd", "", 3642 }, 3643 { 3644 "filering, user not allowed", 3645 func() *Options { 3646 o := testWSOptions() 3647 o.Users = users 3648 // Only allowed for regular clients 3649 o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard}) 3650 
return o 3651 }, 3652 "user", "pwd", "-ERR 'Authorization Violation'", 3653 }, 3654 { 3655 "filtering, user allowed", 3656 func() *Options { 3657 o := testWSOptions() 3658 o.Users = users 3659 o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}) 3660 return o 3661 }, 3662 "user", "pwd", "", 3663 }, 3664 { 3665 "filtering, wrong password", 3666 func() *Options { 3667 o := testWSOptions() 3668 o.Users = users 3669 o.Users[0].AllowedConnectionTypes = testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}) 3670 return o 3671 }, 3672 "user", "badpassword", "-ERR 'Authorization Violation'", 3673 }, 3674 } { 3675 t.Run(test.name, func(t *testing.T) { 3676 o := test.opts() 3677 s := RunServer(o) 3678 defer s.Shutdown() 3679 3680 wsc, br, _ := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) 3681 defer wsc.Close() 3682 3683 connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"%s\"}\r\nPING\r\n", 3684 test.user, test.pass) 3685 3686 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3687 if _, err := wsc.Write(wsmsg); err != nil { 3688 t.Fatalf("Error sending message: %v", err) 3689 } 3690 msg := testWSReadFrame(t, br) 3691 if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3692 t.Fatalf("Expected to receive PONG, got %q", msg) 3693 } else if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { 3694 t.Fatalf("Expected to receive %q, got %q", test.err, msg) 3695 } 3696 }) 3697 } 3698 } 3699 3700 func TestWSNoAuthUserValidation(t *testing.T) { 3701 o := testWSOptions() 3702 o.Users = []*User{{Username: "user", Password: "pwd"}} 3703 // Should fail because it is not part of o.Users. 
3704 o.Websocket.NoAuthUser = "notfound" 3705 if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") { 3706 t.Fatalf("Expected error saying not present as user, got %v", err) 3707 } 3708 // Set a valid no auth user for global options, but still should fail because 3709 // of o.Websocket.NoAuthUser 3710 o.NoAuthUser = "user" 3711 o.Websocket.NoAuthUser = "notfound" 3712 if _, err := NewServer(o); err == nil || !strings.Contains(err.Error(), "not present as user") { 3713 t.Fatalf("Expected error saying not present as user, got %v", err) 3714 } 3715 } 3716 3717 func TestWSNoAuthUser(t *testing.T) { 3718 for _, test := range []struct { 3719 name string 3720 override bool 3721 useAuth bool 3722 expectedUser string 3723 expectedAcc string 3724 }{ 3725 {"no override, no user provided", false, false, "noauth", "normal"}, 3726 {"no override, user povided", false, true, "user", "normal"}, 3727 {"override, no user provided", true, false, "wsnoauth", "websocket"}, 3728 {"override, user provided", true, true, "wsuser", "websocket"}, 3729 } { 3730 t.Run(test.name, func(t *testing.T) { 3731 o := testWSOptions() 3732 normalAcc := NewAccount("normal") 3733 websocketAcc := NewAccount("websocket") 3734 o.Accounts = []*Account{normalAcc, websocketAcc} 3735 o.Users = []*User{ 3736 {Username: "noauth", Password: "pwd", Account: normalAcc}, 3737 {Username: "user", Password: "pwd", Account: normalAcc}, 3738 {Username: "wsnoauth", Password: "pwd", Account: websocketAcc}, 3739 {Username: "wsuser", Password: "pwd", Account: websocketAcc}, 3740 } 3741 o.NoAuthUser = "noauth" 3742 if test.override { 3743 o.Websocket.NoAuthUser = "wsnoauth" 3744 } 3745 s := RunServer(o) 3746 defer s.Shutdown() 3747 3748 wsc, br, l := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) 3749 defer wsc.Close() 3750 3751 var info serverInfo 3752 json.Unmarshal([]byte(l[5:]), &info) 3753 3754 var connectProto string 3755 if test.useAuth { 3756 
connectProto = fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"user\":\"%s\",\"pass\":\"pwd\"}\r\nPING\r\n", 3757 test.expectedUser) 3758 } else { 3759 connectProto = "CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n" 3760 } 3761 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3762 if _, err := wsc.Write(wsmsg); err != nil { 3763 t.Fatalf("Error sending message: %v", err) 3764 } 3765 msg := testWSReadFrame(t, br) 3766 if !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3767 t.Fatalf("Unexpected error: %q", msg) 3768 } 3769 3770 c := s.getClient(info.CID) 3771 c.mu.Lock() 3772 uname := c.opts.Username 3773 aname := c.acc.GetName() 3774 c.mu.Unlock() 3775 if uname != test.expectedUser { 3776 t.Fatalf("Expected selected user to be %q, got %q", test.expectedUser, uname) 3777 } 3778 if aname != test.expectedAcc { 3779 t.Fatalf("Expected selected account to be %q, got %q", test.expectedAcc, aname) 3780 } 3781 }) 3782 } 3783 } 3784 3785 func TestWSNkeyAuth(t *testing.T) { 3786 nkp, _ := nkeys.CreateUser() 3787 pub, _ := nkp.PublicKey() 3788 3789 wsnkp, _ := nkeys.CreateUser() 3790 wspub, _ := wsnkp.PublicKey() 3791 3792 badkp, _ := nkeys.CreateUser() 3793 badpub, _ := badkp.PublicKey() 3794 3795 for _, test := range []struct { 3796 name string 3797 opts func() *Options 3798 nkey string 3799 kp nkeys.KeyPair 3800 err string 3801 }{ 3802 { 3803 "no filtering, wrong nkey", 3804 func() *Options { 3805 o := testWSOptions() 3806 o.Nkeys = []*NkeyUser{{Nkey: pub}} 3807 return o 3808 }, 3809 badpub, badkp, "-ERR 'Authorization Violation'", 3810 }, 3811 { 3812 "no filtering, correct nkey", 3813 func() *Options { 3814 o := testWSOptions() 3815 o.Nkeys = []*NkeyUser{{Nkey: pub}} 3816 return o 3817 }, 3818 pub, nkp, "", 3819 }, 3820 { 3821 "filtering, nkey not allowed", 3822 func() *Options { 3823 o := testWSOptions() 3824 o.Nkeys = []*NkeyUser{ 3825 { 3826 Nkey: pub, 3827 AllowedConnectionTypes: 
testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard}), 3828 }, 3829 { 3830 Nkey: wspub, 3831 AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeWebsocket}), 3832 }, 3833 } 3834 return o 3835 }, 3836 pub, nkp, "-ERR 'Authorization Violation'", 3837 }, 3838 { 3839 "filtering, correct nkey", 3840 func() *Options { 3841 o := testWSOptions() 3842 o.Nkeys = []*NkeyUser{ 3843 {Nkey: pub}, 3844 { 3845 Nkey: wspub, 3846 AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}), 3847 }, 3848 } 3849 return o 3850 }, 3851 wspub, wsnkp, "", 3852 }, 3853 { 3854 "filtering, wrong nkey", 3855 func() *Options { 3856 o := testWSOptions() 3857 o.Nkeys = []*NkeyUser{ 3858 { 3859 Nkey: wspub, 3860 AllowedConnectionTypes: testCreateAllowedConnectionTypes([]string{jwt.ConnectionTypeStandard, jwt.ConnectionTypeWebsocket}), 3861 }, 3862 } 3863 return o 3864 }, 3865 badpub, badkp, "-ERR 'Authorization Violation'", 3866 }, 3867 } { 3868 t.Run(test.name, func(t *testing.T) { 3869 o := test.opts() 3870 s := RunServer(o) 3871 defer s.Shutdown() 3872 3873 wsc, br, infoMsg := testWSCreateClientGetInfo(t, false, false, o.Websocket.Host, o.Websocket.Port) 3874 defer wsc.Close() 3875 3876 // Sign Nonce 3877 var info nonceInfo 3878 json.Unmarshal([]byte(infoMsg[5:]), &info) 3879 sigraw, _ := test.kp.Sign([]byte(info.Nonce)) 3880 sig := base64.RawURLEncoding.EncodeToString(sigraw) 3881 3882 connectProto := fmt.Sprintf("CONNECT {\"verbose\":false,\"protocol\":1,\"nkey\":\"%s\",\"sig\":\"%s\"}\r\nPING\r\n", test.nkey, sig) 3883 3884 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, true, false, []byte(connectProto)) 3885 if _, err := wsc.Write(wsmsg); err != nil { 3886 t.Fatalf("Error sending message: %v", err) 3887 } 3888 msg := testWSReadFrame(t, br) 3889 if test.err == "" && !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 3890 t.Fatalf("Expected to receive PONG, got %q", msg) 3891 } else 
if test.err != "" && !bytes.HasPrefix(msg, []byte(test.err)) { 3892 t.Fatalf("Expected to receive %q, got %q", test.err, msg) 3893 } 3894 }) 3895 } 3896 } 3897 3898 func TestWSSetHeaderServer(t *testing.T) { 3899 o := testWSOptions() 3900 o.Websocket.Headers = map[string]string{ 3901 "X-Custom-Header": "custom-value", 3902 } 3903 3904 s := RunServer(o) 3905 defer s.Shutdown() 3906 3907 opts := testWSClientOptions{ 3908 host: o.Websocket.Host, 3909 port: o.Websocket.Port, 3910 extraResponseHeaders: o.Websocket.Headers, 3911 } 3912 3913 c, _, _ := testNewWSClient(t, opts) 3914 defer c.Close() 3915 } 3916 3917 func TestWSJWTWithAllowedConnectionTypes(t *testing.T) { 3918 o := testWSOptions() 3919 setupAddTrusted(o) 3920 s := RunServer(o) 3921 buildMemAccResolver(s) 3922 defer s.Shutdown() 3923 3924 for _, test := range []struct { 3925 name string 3926 connectionTypes []string 3927 expectedAnswer string 3928 }{ 3929 {"not allowed", []string{jwt.ConnectionTypeStandard}, "-ERR"}, 3930 {"allowed", []string{jwt.ConnectionTypeStandard, strings.ToLower(jwt.ConnectionTypeWebsocket)}, "+OK"}, 3931 {"allowed with unknown", []string{jwt.ConnectionTypeWebsocket, "SomeNewType"}, "+OK"}, 3932 {"not allowed with unknown", []string{"SomeNewType"}, "-ERR"}, 3933 } { 3934 t.Run(test.name, func(t *testing.T) { 3935 nuc := newJWTTestUserClaims() 3936 nuc.AllowedConnectionTypes = test.connectionTypes 3937 claimOpt := testClaimsOptions{ 3938 nuc: nuc, 3939 expectAnswer: test.expectedAnswer, 3940 } 3941 _, c, _, _ := testWSWithClaims(t, s, testWSClientOptions{host: o.Websocket.Host, port: o.Websocket.Port}, claimOpt) 3942 c.Close() 3943 }) 3944 } 3945 } 3946 3947 func TestWSJWTCookieUser(t *testing.T) { 3948 nucSigFunc := func() *jwt.UserClaims { return newJWTTestUserClaims() } 3949 nucBearerFunc := func() *jwt.UserClaims { 3950 ret := newJWTTestUserClaims() 3951 ret.BearerToken = true 3952 return ret 3953 } 3954 3955 o := testWSOptions() 3956 setupAddTrusted(o) 3957 setupAddCookie(o) 3958 
s := RunServer(o) 3959 buildMemAccResolver(s) 3960 defer s.Shutdown() 3961 3962 genJwt := func(t *testing.T, nuc *jwt.UserClaims) string { 3963 okp, _ := nkeys.FromSeed(oSeed) 3964 3965 akp, _ := nkeys.CreateAccount() 3966 apub, _ := akp.PublicKey() 3967 3968 nac := jwt.NewAccountClaims(apub) 3969 ajwt, err := nac.Encode(okp) 3970 if err != nil { 3971 t.Fatalf("Error generating account JWT: %v", err) 3972 } 3973 3974 nkp, _ := nkeys.CreateUser() 3975 pub, _ := nkp.PublicKey() 3976 nuc.Subject = pub 3977 jwt, err := nuc.Encode(akp) 3978 if err != nil { 3979 t.Fatalf("Error generating user JWT: %v", err) 3980 } 3981 addAccountToMemResolver(s, apub, ajwt) 3982 return jwt 3983 } 3984 3985 cliOpts := testWSClientOptions{ 3986 host: o.Websocket.Host, 3987 port: o.Websocket.Port, 3988 } 3989 for _, test := range []struct { 3990 name string 3991 nuc *jwt.UserClaims 3992 opts func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) 3993 expectAnswer string 3994 }{ 3995 { 3996 name: "protocol auth, non-bearer key, with signature", 3997 nuc: nucSigFunc(), 3998 opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 3999 return cliOpts, testClaimsOptions{nuc: claims} 4000 }, 4001 expectAnswer: "+OK", 4002 }, 4003 { 4004 name: "protocol auth, non-bearer key, w/o required signature", 4005 nuc: nucSigFunc(), 4006 opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 4007 return cliOpts, testClaimsOptions{nuc: claims, dontSign: true} 4008 }, 4009 expectAnswer: "-ERR", 4010 }, 4011 { 4012 name: "protocol auth, bearer key, w/o signature", 4013 nuc: nucBearerFunc(), 4014 opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 4015 return cliOpts, testClaimsOptions{nuc: claims, dontSign: true} 4016 }, 4017 expectAnswer: "+OK", 4018 }, 4019 { 4020 name: "cookie auth, non-bearer key, protocol auth fail", 4021 nuc: nucSigFunc(), 4022 opts: func(t 
*testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 4023 co := cliOpts 4024 co.extraHeaders = map[string][]string{} 4025 co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)} 4026 return co, testClaimsOptions{connectRequest: struct{}{}} 4027 }, 4028 expectAnswer: "-ERR", 4029 }, 4030 { 4031 name: "cookie auth, bearer key, protocol auth success with implied cookie jwt", 4032 nuc: nucBearerFunc(), 4033 opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 4034 co := cliOpts 4035 co.extraHeaders = map[string][]string{} 4036 co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)} 4037 return co, testClaimsOptions{connectRequest: struct{}{}} 4038 }, 4039 expectAnswer: "+OK", 4040 }, 4041 { 4042 name: "cookie auth, non-bearer key, protocol auth success via override jwt in CONNECT opts", 4043 nuc: nucSigFunc(), 4044 opts: func(t *testing.T, claims *jwt.UserClaims) (testWSClientOptions, testClaimsOptions) { 4045 co := cliOpts 4046 co.extraHeaders = map[string][]string{} 4047 co.extraHeaders["Cookie"] = []string{o.Websocket.JWTCookie + "=" + genJwt(t, claims)} 4048 return co, testClaimsOptions{nuc: nucBearerFunc()} 4049 }, 4050 expectAnswer: "+OK", 4051 }, 4052 } { 4053 t.Run(test.name, func(t *testing.T) { 4054 cliOpt, claimOpt := test.opts(t, test.nuc) 4055 claimOpt.expectAnswer = test.expectAnswer 4056 _, c, _, _ := testWSWithClaims(t, s, cliOpt, claimOpt) 4057 c.Close() 4058 }) 4059 } 4060 s.Shutdown() 4061 } 4062 4063 func TestWSReloadTLSConfig(t *testing.T) { 4064 template := ` 4065 listen: "127.0.0.1:-1" 4066 websocket { 4067 listen: "127.0.0.1:-1" 4068 tls { 4069 cert_file: '%s' 4070 key_file: '%s' 4071 ca_file: '../test/configs/certs/ca.pem' 4072 } 4073 } 4074 ` 4075 conf := createConfFile(t, []byte(fmt.Sprintf(template, 4076 "../test/configs/certs/server-noip.pem", 4077 "../test/configs/certs/server-key-noip.pem"))) 4078 4079 s, o := 
RunServerWithConfig(conf) 4080 defer s.Shutdown() 4081 4082 addr := fmt.Sprintf("127.0.0.1:%d", o.Websocket.Port) 4083 wsc, err := net.Dial("tcp", addr) 4084 if err != nil { 4085 t.Fatalf("Error creating ws connection: %v", err) 4086 } 4087 defer wsc.Close() 4088 4089 tc := &TLSConfigOpts{CaFile: "../test/configs/certs/ca.pem"} 4090 tlsConfig, err := GenTLSConfig(tc) 4091 if err != nil { 4092 t.Fatalf("Error generating TLS config: %v", err) 4093 } 4094 tlsConfig.ServerName = "127.0.0.1" 4095 tlsConfig.RootCAs = tlsConfig.ClientCAs 4096 tlsConfig.ClientCAs = nil 4097 wsc = tls.Client(wsc, tlsConfig.Clone()) 4098 if err := wsc.(*tls.Conn).Handshake(); err == nil || !strings.Contains(err.Error(), "SAN") { 4099 t.Fatalf("Unexpected error: %v", err) 4100 } 4101 wsc.Close() 4102 4103 reloadUpdateConfig(t, s, conf, fmt.Sprintf(template, 4104 "../test/configs/certs/server-cert.pem", 4105 "../test/configs/certs/server-key.pem")) 4106 4107 wsc, err = net.Dial("tcp", addr) 4108 if err != nil { 4109 t.Fatalf("Error creating ws connection: %v", err) 4110 } 4111 defer wsc.Close() 4112 4113 wsc = tls.Client(wsc, tlsConfig.Clone()) 4114 if err := wsc.(*tls.Conn).Handshake(); err != nil { 4115 t.Fatalf("Error on TLS handshake: %v", err) 4116 } 4117 } 4118 4119 type captureClientConnectedLogger struct { 4120 DummyLogger 4121 ch chan string 4122 } 4123 4124 func (l *captureClientConnectedLogger) Debugf(format string, v ...any) { 4125 msg := fmt.Sprintf(format, v...) 
4126 if !strings.Contains(msg, "Client connection created") { 4127 return 4128 } 4129 select { 4130 case l.ch <- msg: 4131 default: 4132 } 4133 } 4134 4135 func TestWSXForwardedFor(t *testing.T) { 4136 o := testWSOptions() 4137 s := RunServer(o) 4138 defer s.Shutdown() 4139 4140 l := &captureClientConnectedLogger{ch: make(chan string, 1)} 4141 s.SetLogger(l, true, false) 4142 4143 for _, test := range []struct { 4144 name string 4145 headers func() map[string][]string 4146 useHdrValue bool 4147 expectedValue string 4148 }{ 4149 {"nil map", func() map[string][]string { 4150 return nil 4151 }, false, _EMPTY_}, 4152 {"empty map", func() map[string][]string { 4153 return make(map[string][]string) 4154 }, false, _EMPTY_}, 4155 {"header present empty value", func() map[string][]string { 4156 m := make(map[string][]string) 4157 m[wsXForwardedForHeader] = []string{} 4158 return m 4159 }, false, _EMPTY_}, 4160 {"header present invalid IP", func() map[string][]string { 4161 m := make(map[string][]string) 4162 m[wsXForwardedForHeader] = []string{"not a valid IP"} 4163 return m 4164 }, false, _EMPTY_}, 4165 {"header present one IP", func() map[string][]string { 4166 m := make(map[string][]string) 4167 m[wsXForwardedForHeader] = []string{"1.2.3.4"} 4168 return m 4169 }, true, "1.2.3.4"}, 4170 {"header present multiple IPs", func() map[string][]string { 4171 m := make(map[string][]string) 4172 m[wsXForwardedForHeader] = []string{"1.2.3.4", "5.6.7.8"} 4173 return m 4174 }, true, "1.2.3.4"}, 4175 {"header present IPv6", func() map[string][]string { 4176 m := make(map[string][]string) 4177 m[wsXForwardedForHeader] = []string{"::1"} 4178 return m 4179 }, true, "[::1]"}, 4180 } { 4181 t.Run(test.name, func(t *testing.T) { 4182 c, r, _ := testNewWSClient(t, testWSClientOptions{ 4183 host: o.Websocket.Host, 4184 port: o.Websocket.Port, 4185 extraHeaders: test.headers(), 4186 }) 4187 defer c.Close() 4188 // Send CONNECT and PING 4189 wsmsg := testWSCreateClientMsg(wsBinaryMessage, 1, 
true, false, []byte("CONNECT {\"verbose\":false,\"protocol\":1}\r\nPING\r\n")) 4190 if _, err := c.Write(wsmsg); err != nil { 4191 t.Fatalf("Error sending message: %v", err) 4192 } 4193 // Wait for the PONG 4194 if msg := testWSReadFrame(t, r); !bytes.HasPrefix(msg, []byte("PONG\r\n")) { 4195 t.Fatalf("Expected PONG, got %s", msg) 4196 } 4197 select { 4198 case d := <-l.ch: 4199 ipAndSlash := fmt.Sprintf("%s/", test.expectedValue) 4200 if test.useHdrValue { 4201 if !strings.HasPrefix(d, ipAndSlash) { 4202 t.Fatalf("Expected debug statement to start with: %q, got %q", ipAndSlash, d) 4203 } 4204 } else if strings.HasPrefix(d, ipAndSlash) { 4205 t.Fatalf("Unexpected debug statement: %q", d) 4206 } 4207 case <-time.After(time.Second): 4208 t.Fatal("Did not get connect debug statement") 4209 } 4210 }) 4211 } 4212 } 4213 4214 type partialWriteConn struct { 4215 net.Conn 4216 } 4217 4218 func (c *partialWriteConn) Write(b []byte) (int, error) { 4219 max := len(b) 4220 if max > 0 { 4221 max = rand.Intn(max) 4222 if max == 0 { 4223 max = 1 4224 } 4225 } 4226 n, err := c.Conn.Write(b[:max]) 4227 if err == nil && max != len(b) { 4228 err = io.ErrShortWrite 4229 } 4230 return n, err 4231 } 4232 4233 func TestWSWithPartialWrite(t *testing.T) { 4234 conf := createConfFile(t, []byte(` 4235 listen: "127.0.0.1:-1" 4236 websocket { 4237 listen: "127.0.0.1:-1" 4238 no_tls: true 4239 } 4240 `)) 4241 s, o := RunServerWithConfig(conf) 4242 defer s.Shutdown() 4243 4244 nc1 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o.Websocket.Port)) 4245 defer nc1.Close() 4246 4247 sub := natsSubSync(t, nc1, "foo") 4248 sub.SetPendingLimits(-1, -1) 4249 natsFlush(t, nc1) 4250 4251 nc2 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o.Websocket.Port)) 4252 defer nc2.Close() 4253 4254 // Replace websocket connections with ones that will produce short writes. 
4255 s.mu.RLock() 4256 for _, c := range s.clients { 4257 c.mu.Lock() 4258 c.nc = &partialWriteConn{Conn: c.nc} 4259 c.mu.Unlock() 4260 } 4261 s.mu.RUnlock() 4262 4263 var msgs [][]byte 4264 for i := 0; i < 100; i++ { 4265 msg := make([]byte, rand.Intn(10000)+10) 4266 for j := 0; j < len(msg); j++ { 4267 msg[j] = byte('A' + j%26) 4268 } 4269 msgs = append(msgs, msg) 4270 natsPub(t, nc2, "foo", msg) 4271 } 4272 for i := 0; i < 100; i++ { 4273 rmsg := natsNexMsg(t, sub, time.Second) 4274 if !bytes.Equal(msgs[i], rmsg.Data) { 4275 t.Fatalf("Expected message %q, got %q", msgs[i], rmsg.Data) 4276 } 4277 } 4278 } 4279 4280 func testWSNoCorruptionWithFrameSizeLimit(t *testing.T, total int) { 4281 tmpl := ` 4282 listen: "127.0.0.1:-1" 4283 cluster { 4284 name: "local" 4285 port: -1 4286 %s 4287 } 4288 websocket { 4289 listen: "127.0.0.1:-1" 4290 no_tls: true 4291 } 4292 ` 4293 conf1 := createConfFile(t, []byte(fmt.Sprintf(tmpl, _EMPTY_))) 4294 s1, o1 := RunServerWithConfig(conf1) 4295 defer s1.Shutdown() 4296 4297 routes := fmt.Sprintf("routes: [\"nats://127.0.0.1:%d\"]", o1.Cluster.Port) 4298 conf2 := createConfFile(t, []byte(fmt.Sprintf(tmpl, routes))) 4299 s2, o2 := RunServerWithConfig(conf2) 4300 defer s2.Shutdown() 4301 4302 conf3 := createConfFile(t, []byte(fmt.Sprintf(tmpl, routes))) 4303 s3, o3 := RunServerWithConfig(conf3) 4304 defer s3.Shutdown() 4305 4306 checkClusterFormed(t, s1, s2, s3) 4307 4308 nc3 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o3.Websocket.Port)) 4309 defer nc3.Close() 4310 4311 nc2 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o2.Websocket.Port)) 4312 defer nc2.Close() 4313 4314 payload := make([]byte, 100000) 4315 for i := 0; i < len(payload); i++ { 4316 payload[i] = 'A' + byte(i%26) 4317 } 4318 errCh := make(chan error, 1) 4319 doneCh := make(chan struct{}, 1) 4320 count := int32(0) 4321 4322 createSub := func(nc *nats.Conn) { 4323 sub := natsSub(t, nc, "foo", func(m *nats.Msg) { 4324 if !bytes.Equal(m.Data, payload) { 4325 stop 
:= len(m.Data) 4326 if l := len(payload); l < stop { 4327 stop = l 4328 } 4329 start := 0 4330 for i := 0; i < stop; i++ { 4331 if m.Data[i] != payload[i] { 4332 start = i 4333 break 4334 } 4335 } 4336 if stop-start > 20 { 4337 stop = start + 20 4338 } 4339 select { 4340 case errCh <- fmt.Errorf("Invalid message: [%d bytes same]%s[...]", start, m.Data[start:stop]): 4341 default: 4342 } 4343 return 4344 } 4345 if n := atomic.AddInt32(&count, 1); int(n) == 2*total { 4346 doneCh <- struct{}{} 4347 } 4348 }) 4349 sub.SetPendingLimits(-1, -1) 4350 } 4351 createSub(nc2) 4352 createSub(nc3) 4353 4354 checkSubInterest(t, s1, globalAccountName, "foo", time.Second) 4355 4356 nc1 := natsConnect(t, fmt.Sprintf("ws://127.0.0.1:%d", o1.Websocket.Port)) 4357 defer nc1.Close() 4358 natsFlush(t, nc1) 4359 4360 // Change websocket connections to force a max frame size. 4361 for _, s := range []*Server{s1, s2, s3} { 4362 s.mu.RLock() 4363 for _, c := range s.clients { 4364 c.mu.Lock() 4365 if c.ws != nil { 4366 c.ws.browser = true 4367 } 4368 c.mu.Unlock() 4369 } 4370 s.mu.RUnlock() 4371 } 4372 4373 for i := 0; i < total; i++ { 4374 natsPub(t, nc1, "foo", payload) 4375 if i%100 == 0 { 4376 select { 4377 case err := <-errCh: 4378 t.Fatalf("Error: %v", err) 4379 default: 4380 } 4381 } 4382 } 4383 select { 4384 case err := <-errCh: 4385 t.Fatalf("Error: %v", err) 4386 case <-doneCh: 4387 return 4388 case <-time.After(10 * time.Second): 4389 t.Fatalf("Test timed out") 4390 } 4391 } 4392 4393 func TestWSNoCorruptionWithFrameSizeLimit(t *testing.T) { 4394 testWSNoCorruptionWithFrameSizeLimit(t, 1000) 4395 } 4396 4397 // ================================================================== 4398 // = Benchmark tests 4399 // ================================================================== 4400 4401 const testWSBenchSubject = "a" 4402 4403 var ch = []byte("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789!@$#%^&*()") 4404 4405 func sizedString(sz int) string { 4406 b := 
make([]byte, sz) 4407 for i := range b { 4408 b[i] = ch[rand.Intn(len(ch))] 4409 } 4410 return string(b) 4411 } 4412 4413 func sizedStringForCompression(sz int) string { 4414 b := make([]byte, sz) 4415 c := byte(0) 4416 s := 0 4417 for i := range b { 4418 if s%20 == 0 { 4419 c = ch[rand.Intn(len(ch))] 4420 } 4421 b[i] = c 4422 } 4423 return string(b) 4424 } 4425 4426 func testWSFlushConn(b *testing.B, compress bool, c net.Conn, br *bufio.Reader) { 4427 buf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, []byte(pingProto)) 4428 c.Write(buf) 4429 c.SetReadDeadline(time.Now().Add(5 * time.Second)) 4430 res := testWSReadFrame(b, br) 4431 c.SetReadDeadline(time.Time{}) 4432 if !bytes.HasPrefix(res, []byte(pongProto)) { 4433 b.Fatalf("Failed read of PONG: %s\n", res) 4434 } 4435 } 4436 4437 func wsBenchPub(b *testing.B, numPubs int, compress bool, payload string) { 4438 b.StopTimer() 4439 opts := testWSOptions() 4440 opts.Websocket.Compression = compress 4441 s := RunServer(opts) 4442 defer s.Shutdown() 4443 4444 extra := 0 4445 pubProto := []byte(fmt.Sprintf("PUB %s %d\r\n%s\r\n", testWSBenchSubject, len(payload), payload)) 4446 singleOpBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, pubProto) 4447 4448 // Simulate client that would buffer messages before framing/sending. 4449 // Figure out how many we can fit in one frame based on b.N and length of pubProto 4450 const bufSize = 32768 4451 tmpa := [bufSize]byte{} 4452 tmp := tmpa[:0] 4453 pb := 0 4454 for i := 0; i < b.N; i++ { 4455 tmp = append(tmp, pubProto...) 
4456 pb++ 4457 if len(tmp) >= bufSize { 4458 break 4459 } 4460 } 4461 sendBuf := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, tmp) 4462 n := b.N / pb 4463 extra = b.N - (n * pb) 4464 4465 wg := sync.WaitGroup{} 4466 wg.Add(numPubs) 4467 4468 type pub struct { 4469 c net.Conn 4470 br *bufio.Reader 4471 bw *bufio.Writer 4472 } 4473 var pubs []pub 4474 for i := 0; i < numPubs; i++ { 4475 wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port) 4476 defer wsc.Close() 4477 bw := bufio.NewWriterSize(wsc, bufSize) 4478 pubs = append(pubs, pub{wsc, br, bw}) 4479 } 4480 4481 // Average the amount of bytes sent by iteration 4482 avg := len(sendBuf) / pb 4483 if extra > 0 { 4484 avg += len(singleOpBuf) 4485 avg /= 2 4486 } 4487 b.SetBytes(int64(numPubs * avg)) 4488 b.StartTimer() 4489 4490 for i := 0; i < numPubs; i++ { 4491 p := pubs[i] 4492 go func(p pub) { 4493 defer wg.Done() 4494 for i := 0; i < n; i++ { 4495 p.bw.Write(sendBuf) 4496 } 4497 for i := 0; i < extra; i++ { 4498 p.bw.Write(singleOpBuf) 4499 } 4500 p.bw.Flush() 4501 testWSFlushConn(b, compress, p.c, p.br) 4502 }(p) 4503 } 4504 wg.Wait() 4505 b.StopTimer() 4506 } 4507 4508 func Benchmark_WS_Pubx1_CN_____0b(b *testing.B) { 4509 wsBenchPub(b, 1, false, "") 4510 } 4511 4512 func Benchmark_WS_Pubx1_CY_____0b(b *testing.B) { 4513 wsBenchPub(b, 1, true, "") 4514 } 4515 4516 func Benchmark_WS_Pubx1_CN___128b(b *testing.B) { 4517 s := sizedString(128) 4518 wsBenchPub(b, 1, false, s) 4519 } 4520 4521 func Benchmark_WS_Pubx1_CY___128b(b *testing.B) { 4522 s := sizedStringForCompression(128) 4523 wsBenchPub(b, 1, true, s) 4524 } 4525 4526 func Benchmark_WS_Pubx1_CN__1024b(b *testing.B) { 4527 s := sizedString(1024) 4528 wsBenchPub(b, 1, false, s) 4529 } 4530 4531 func Benchmark_WS_Pubx1_CY__1024b(b *testing.B) { 4532 s := sizedStringForCompression(1024) 4533 wsBenchPub(b, 1, true, s) 4534 } 4535 4536 func Benchmark_WS_Pubx1_CN__4096b(b *testing.B) { 4537 s := sizedString(4 * 
1024) 4538 wsBenchPub(b, 1, false, s) 4539 } 4540 4541 func Benchmark_WS_Pubx1_CY__4096b(b *testing.B) { 4542 s := sizedStringForCompression(4 * 1024) 4543 wsBenchPub(b, 1, true, s) 4544 } 4545 4546 func Benchmark_WS_Pubx1_CN__8192b(b *testing.B) { 4547 s := sizedString(8 * 1024) 4548 wsBenchPub(b, 1, false, s) 4549 } 4550 4551 func Benchmark_WS_Pubx1_CY__8192b(b *testing.B) { 4552 s := sizedStringForCompression(8 * 1024) 4553 wsBenchPub(b, 1, true, s) 4554 } 4555 4556 func Benchmark_WS_Pubx1_CN_32768b(b *testing.B) { 4557 s := sizedString(32 * 1024) 4558 wsBenchPub(b, 1, false, s) 4559 } 4560 4561 func Benchmark_WS_Pubx1_CY_32768b(b *testing.B) { 4562 s := sizedStringForCompression(32 * 1024) 4563 wsBenchPub(b, 1, true, s) 4564 } 4565 4566 func Benchmark_WS_Pubx5_CN_____0b(b *testing.B) { 4567 wsBenchPub(b, 5, false, "") 4568 } 4569 4570 func Benchmark_WS_Pubx5_CY_____0b(b *testing.B) { 4571 wsBenchPub(b, 5, true, "") 4572 } 4573 4574 func Benchmark_WS_Pubx5_CN___128b(b *testing.B) { 4575 s := sizedString(128) 4576 wsBenchPub(b, 5, false, s) 4577 } 4578 4579 func Benchmark_WS_Pubx5_CY___128b(b *testing.B) { 4580 s := sizedStringForCompression(128) 4581 wsBenchPub(b, 5, true, s) 4582 } 4583 4584 func Benchmark_WS_Pubx5_CN__1024b(b *testing.B) { 4585 s := sizedString(1024) 4586 wsBenchPub(b, 5, false, s) 4587 } 4588 4589 func Benchmark_WS_Pubx5_CY__1024b(b *testing.B) { 4590 s := sizedStringForCompression(1024) 4591 wsBenchPub(b, 5, true, s) 4592 } 4593 4594 func Benchmark_WS_Pubx5_CN__4096b(b *testing.B) { 4595 s := sizedString(4 * 1024) 4596 wsBenchPub(b, 5, false, s) 4597 } 4598 4599 func Benchmark_WS_Pubx5_CY__4096b(b *testing.B) { 4600 s := sizedStringForCompression(4 * 1024) 4601 wsBenchPub(b, 5, true, s) 4602 } 4603 4604 func Benchmark_WS_Pubx5_CN__8192b(b *testing.B) { 4605 s := sizedString(8 * 1024) 4606 wsBenchPub(b, 5, false, s) 4607 } 4608 4609 func Benchmark_WS_Pubx5_CY__8192b(b *testing.B) { 4610 s := sizedStringForCompression(8 * 1024) 4611 
wsBenchPub(b, 5, true, s) 4612 } 4613 4614 func Benchmark_WS_Pubx5_CN_32768b(b *testing.B) { 4615 s := sizedString(32 * 1024) 4616 wsBenchPub(b, 5, false, s) 4617 } 4618 4619 func Benchmark_WS_Pubx5_CY_32768b(b *testing.B) { 4620 s := sizedStringForCompression(32 * 1024) 4621 wsBenchPub(b, 5, true, s) 4622 } 4623 4624 func wsBenchSub(b *testing.B, numSubs int, compress bool, payload string) { 4625 b.StopTimer() 4626 opts := testWSOptions() 4627 opts.Websocket.Compression = compress 4628 s := RunServer(opts) 4629 defer s.Shutdown() 4630 4631 var subs []*bufio.Reader 4632 for i := 0; i < numSubs; i++ { 4633 wsc, br := testWSCreateClient(b, compress, false, opts.Websocket.Host, opts.Websocket.Port) 4634 defer wsc.Close() 4635 subProto := testWSCreateClientMsg(wsBinaryMessage, 1, true, compress, 4636 []byte(fmt.Sprintf("SUB %s 1\r\nPING\r\n", testWSBenchSubject))) 4637 wsc.Write(subProto) 4638 // Waiting for PONG 4639 testWSReadFrame(b, br) 4640 subs = append(subs, br) 4641 } 4642 4643 wg := sync.WaitGroup{} 4644 wg.Add(numSubs) 4645 4646 // Use regular NATS client to publish messages 4647 nc := natsConnect(b, s.ClientURL()) 4648 defer nc.Close() 4649 4650 b.StartTimer() 4651 4652 for i := 0; i < numSubs; i++ { 4653 br := subs[i] 4654 go func(br *bufio.Reader) { 4655 defer wg.Done() 4656 for count := 0; count < b.N; { 4657 msgs := testWSReadFrame(b, br) 4658 count += bytes.Count(msgs, []byte("MSG ")) 4659 } 4660 }(br) 4661 } 4662 for i := 0; i < b.N; i++ { 4663 natsPub(b, nc, testWSBenchSubject, []byte(payload)) 4664 } 4665 wg.Wait() 4666 b.StopTimer() 4667 } 4668 4669 func Benchmark_WS_Subx1_CN_____0b(b *testing.B) { 4670 wsBenchSub(b, 1, false, "") 4671 } 4672 4673 func Benchmark_WS_Subx1_CY_____0b(b *testing.B) { 4674 wsBenchSub(b, 1, true, "") 4675 } 4676 4677 func Benchmark_WS_Subx1_CN___128b(b *testing.B) { 4678 s := sizedString(128) 4679 wsBenchSub(b, 1, false, s) 4680 } 4681 4682 func Benchmark_WS_Subx1_CY___128b(b *testing.B) { 4683 s := 
sizedStringForCompression(128) 4684 wsBenchSub(b, 1, true, s) 4685 } 4686 4687 func Benchmark_WS_Subx1_CN__1024b(b *testing.B) { 4688 s := sizedString(1024) 4689 wsBenchSub(b, 1, false, s) 4690 } 4691 4692 func Benchmark_WS_Subx1_CY__1024b(b *testing.B) { 4693 s := sizedStringForCompression(1024) 4694 wsBenchSub(b, 1, true, s) 4695 } 4696 4697 func Benchmark_WS_Subx1_CN__4096b(b *testing.B) { 4698 s := sizedString(4096) 4699 wsBenchSub(b, 1, false, s) 4700 } 4701 4702 func Benchmark_WS_Subx1_CY__4096b(b *testing.B) { 4703 s := sizedStringForCompression(4096) 4704 wsBenchSub(b, 1, true, s) 4705 } 4706 4707 func Benchmark_WS_Subx1_CN__8192b(b *testing.B) { 4708 s := sizedString(8192) 4709 wsBenchSub(b, 1, false, s) 4710 } 4711 4712 func Benchmark_WS_Subx1_CY__8192b(b *testing.B) { 4713 s := sizedStringForCompression(8192) 4714 wsBenchSub(b, 1, true, s) 4715 } 4716 4717 func Benchmark_WS_Subx1_CN_32768b(b *testing.B) { 4718 s := sizedString(32768) 4719 wsBenchSub(b, 1, false, s) 4720 } 4721 4722 func Benchmark_WS_Subx1_CY_32768b(b *testing.B) { 4723 s := sizedStringForCompression(32768) 4724 wsBenchSub(b, 1, true, s) 4725 } 4726 4727 func Benchmark_WS_Subx5_CN_____0b(b *testing.B) { 4728 wsBenchSub(b, 5, false, "") 4729 } 4730 4731 func Benchmark_WS_Subx5_CY_____0b(b *testing.B) { 4732 wsBenchSub(b, 5, true, "") 4733 } 4734 4735 func Benchmark_WS_Subx5_CN___128b(b *testing.B) { 4736 s := sizedString(128) 4737 wsBenchSub(b, 5, false, s) 4738 } 4739 4740 func Benchmark_WS_Subx5_CY___128b(b *testing.B) { 4741 s := sizedStringForCompression(128) 4742 wsBenchSub(b, 5, true, s) 4743 } 4744 4745 func Benchmark_WS_Subx5_CN__1024b(b *testing.B) { 4746 s := sizedString(1024) 4747 wsBenchSub(b, 5, false, s) 4748 } 4749 4750 func Benchmark_WS_Subx5_CY__1024b(b *testing.B) { 4751 s := sizedStringForCompression(1024) 4752 wsBenchSub(b, 5, true, s) 4753 } 4754 4755 func Benchmark_WS_Subx5_CN__4096b(b *testing.B) { 4756 s := sizedString(4096) 4757 wsBenchSub(b, 5, false, s) 4758 } 
4759 4760 func Benchmark_WS_Subx5_CY__4096b(b *testing.B) { 4761 s := sizedStringForCompression(4096) 4762 wsBenchSub(b, 5, true, s) 4763 } 4764 4765 func Benchmark_WS_Subx5_CN__8192b(b *testing.B) { 4766 s := sizedString(8192) 4767 wsBenchSub(b, 5, false, s) 4768 } 4769 4770 func Benchmark_WS_Subx5_CY__8192b(b *testing.B) { 4771 s := sizedStringForCompression(8192) 4772 wsBenchSub(b, 5, true, s) 4773 } 4774 4775 func Benchmark_WS_Subx5_CN_32768b(b *testing.B) { 4776 s := sizedString(32768) 4777 wsBenchSub(b, 5, false, s) 4778 } 4779 4780 func Benchmark_WS_Subx5_CY_32768b(b *testing.B) { 4781 s := sizedStringForCompression(32768) 4782 wsBenchSub(b, 5, true, s) 4783 }