decred.org/dcrdex@v1.0.5/server/comms/comms_test.go (about) 1 package comms 2 3 import ( 4 "bytes" 5 "context" 6 "crypto/tls" 7 "crypto/x509" 8 "encoding/json" 9 "errors" 10 "fmt" 11 "net/http" 12 "net/http/httptest" 13 "net/url" 14 "os" 15 "path/filepath" 16 "strings" 17 "sync" 18 "sync/atomic" 19 "testing" 20 "time" 21 22 "decred.org/dcrdex/dex" 23 "decred.org/dcrdex/dex/msgjson" 24 "decred.org/dcrdex/dex/ws" 25 "github.com/gorilla/websocket" 26 ) 27 28 var ( 29 tErr = fmt.Errorf("test error") 30 testCtx context.Context 31 tLogger = dex.StdOutLogger("TCOMMS", dex.LevelTrace) 32 ) 33 34 func newServer() *Server { 35 s := &Server{ 36 clients: make(map[uint64]*wsLink), 37 wsLimiters: make(map[dex.IPKey]*ipWsLimiter), 38 v6Prefixes: make(map[dex.IPKey]int), 39 quarantine: make(map[dex.IPKey]time.Time), 40 dataEnabled: 1, 41 rpcRoutes: make(map[string]MsgHandler), 42 httpRoutes: make(map[string]HTTPHandler), 43 } 44 for _, route := range []string{msgjson.ConfigRoute, msgjson.SpotsRoute, msgjson.CandlesRoute, msgjson.OrderBookRoute} { 45 s.RegisterHTTP(route, func(any) (any, error) { return nil, nil }) 46 } 47 return s 48 } 49 50 func giveItASecond(f func() bool) bool { 51 ticker := time.NewTicker(time.Millisecond) 52 timeout := time.NewTimer(time.Second) 53 for { 54 if f() { 55 return true 56 } 57 select { 58 case <-timeout.C: 59 return false 60 default: 61 } 62 <-ticker.C 63 } 64 } 65 66 func readChannel(t *testing.T, tag string, c chan any) any { 67 t.Helper() 68 select { 69 case i := <-c: 70 return i 71 case <-time.NewTimer(time.Second).C: 72 t.Fatalf("%s: didn't read channel", tag) 73 } 74 return nil 75 } 76 77 func decodeResponse(t *testing.T, b []byte) *msgjson.ResponsePayload { 78 t.Helper() 79 msg, err := msgjson.DecodeMessage(b) 80 if err != nil { 81 t.Fatalf("error decoding last message (%s): %v", string(b), err) 82 } 83 resp, err := msg.Response() 84 if err != nil { 85 t.Fatalf("error decoding response payload: %v", err) 86 } 87 return resp 88 } 89 90 type wsConnStub struct { 91 msg chan []byte 92 quit chan struct{} 93 close int 94 recv chan []byte 95 nextRead chan struct{} // helps detect when (*WSLink).inHandler is running 96 writeMtx sync.Mutex 97 writeErr error 98 } 99 100 func (conn *wsConnStub) addChan() { 101 conn.recv = make(chan []byte) 102 } 103 104 func (conn *wsConnStub) addNextChan() { 105 conn.nextRead = make(chan struct{}, 1) // send when ReadMessage() is called 106 } 107 108 func (conn *wsConnStub) wait(t *testing.T, tag string) { 109 t.Helper() 110 select { 111 case <-conn.recv: 112 case <-time.NewTimer(time.Second).C: 113 t.Fatalf("%s - wait timeout", tag) 114 } 115 } 116 117 func newWsStub() *wsConnStub { 118 return &wsConnStub{ 119 msg: make(chan []byte), 120 // recv is nil unless a test wants to receive 121 quit: make(chan struct{}), 122 } 123 } 124 125 func (conn *wsConnStub) setWriteErr(err error) { 126 conn.writeMtx.Lock() 127 conn.writeErr = err 128 conn.writeMtx.Unlock() 129 } 130 131 // nonEOF can specify a particular error should be returned through ReadMessage. 132 var nonEOF = make(chan struct{}) 133 var pongTrigger = []byte("pong") 134 135 func (conn *wsConnStub) ReadMessage() (int, []byte, error) { 136 if conn.nextRead != nil { 137 conn.nextRead <- struct{}{} 138 } 139 140 var b []byte 141 select { 142 case b = <-conn.msg: 143 if bytes.Equal(b, pongTrigger) { 144 return websocket.PongMessage, []byte{}, nil 145 } 146 case <-conn.quit: 147 return 0, nil, &websocket.CloseError{Code: websocket.CloseGoingAway, Text: "bye"} 148 case <-testCtx.Done(): 149 return 0, nil, &websocket.CloseError{Code: websocket.CloseGoingAway, Text: "bye"} 150 case <-nonEOF: 151 close(conn.quit) 152 return 0, nil, fmt.Errorf("test nonEOF error") 153 } 154 return 0, b, nil 155 } 156 157 func (conn *wsConnStub) WriteMessage(msgType int, msg []byte) error { 158 conn.writeMtx.Lock() 159 defer conn.writeMtx.Unlock() 160 if msgType == websocket.PingMessage { 161 select { 162 case conn.msg <- pongTrigger: 163 default: 164 } 165 return nil 166 } 167 // Send the message if their is a receiver for the current test. 168 if conn.recv != nil { 169 conn.recv <- msg 170 } 171 if conn.writeErr == nil { 172 return nil 173 } 174 err := conn.writeErr 175 conn.writeErr = nil 176 return err 177 } 178 179 func (conn *wsConnStub) SetReadLimit(int64) {} 180 181 func (conn *wsConnStub) SetWriteDeadline(t time.Time) error { 182 return nil // TODO implement and test write timeouts 183 } 184 185 func (conn *wsConnStub) SetReadDeadline(t time.Time) error { 186 return nil 187 } 188 189 func (conn *wsConnStub) WriteControl(messageType int, data []byte, deadline time.Time) error { 190 return nil 191 } 192 193 func (conn *wsConnStub) Close() error { 194 select { 195 case <-conn.quit: 196 default: 197 close(conn.quit) 198 } 199 conn.close++ 200 return nil 201 } 202 203 func dummyRPCHandler(_ Link, _ *msgjson.Message) *msgjson.Error { 204 return nil 205 } 206 207 var reqID uint64 208 209 func makeReq(route, msg string) *msgjson.Message { 210 reqID++ 211 req, err := msgjson.NewRequest(reqID, route, json.RawMessage(msg)) 212 if err != nil { 213 panic("bad request message") 214 } 215 return req 216 } 217 218 func makeResp(id uint64, msg string) *msgjson.Message { 219 resp, _ := msgjson.NewResponse(id, json.RawMessage(msg), nil) 220 return resp 221 } 222 223 func makeNtfn(route, msg string) *msgjson.Message { 224 ntfn, _ := msgjson.NewNotification(route, json.RawMessage(msg)) 225 return ntfn 226 } 227 228 func sendToConn(t *testing.T, conn *wsConnStub, method, msg string) { 229 t.Helper() 230 encMsg, err := json.Marshal(makeReq(method, msg)) 231 if err != nil { 232 t.Fatalf("error encoding %s request: %v", method, err) 233 } 234 conn.msg <- encMsg 235 } 236 237 func sendReplace(t *testing.T, conn *wsConnStub, thing any, old, new string) { 238 enc, err := json.Marshal(thing) 239 if err != nil { 240 t.Fatalf("error encoding thing for sendReplace: %v", err) 241 } 242 s := string(enc) 243 s = strings.ReplaceAll(s, old, new) 244 conn.msg <- []byte(s) 245 } 246 247 func newTestBisonWallet(addr string, rootCAs *x509.CertPool) (*websocket.Conn, error) { 248 uri, err := url.Parse(addr) 249 if err != nil { 250 return nil, fmt.Errorf("error parsing url: %w", err) 251 } 252 253 dialer := &websocket.Dialer{ 254 Proxy: http.ProxyFromEnvironment, // Same as DefaultDialer. 255 HandshakeTimeout: 10 * time.Second, // DefaultDialer is 45 seconds. 256 TLSClientConfig: &tls.Config{ 257 RootCAs: rootCAs, 258 InsecureSkipVerify: true, 259 ServerName: uri.Hostname(), 260 }, 261 } 262 263 conn, _, err := dialer.Dial(addr, nil) 264 if err != nil { 265 return nil, err 266 } 267 return conn, nil 268 } 269 270 func TestMain(m *testing.M) { 271 var shutdown func() 272 testCtx, shutdown = context.WithCancel(context.Background()) 273 defer shutdown() 274 UseLogger(tLogger) 275 os.Exit(m.Run()) 276 } 277 278 // method strings cannot be empty. 279 func TestRoute_PanicsEmptyString(t *testing.T) { 280 defer func() { 281 if r := recover(); r == nil { 282 t.Fatalf("no panic on registering empty string method") 283 } 284 }() 285 s := newServer() 286 s.Route("", dummyRPCHandler) 287 } 288 289 // methods cannot be registered more than once. 290 func TestRoute_PanicsDoubleRegistry(t *testing.T) { 291 defer func() { 292 if r := recover(); r == nil { 293 t.Fatalf("no panic on registering empty string method") 294 } 295 }() 296 s := newServer() 297 s.Route("somemethod", dummyRPCHandler) 298 s.Route("somemethod", dummyRPCHandler) 299 } 300 301 // Test the server with a stub for the client connections. 302 func TestClientRequests(t *testing.T) { 303 server := newServer() 304 var wg sync.WaitGroup 305 defer func() { 306 server.disconnectClients() 307 wg.Wait() 308 }() 309 var client *wsLink 310 var conn *wsConnStub 311 stubAddr := dex.IPKey{} 312 copy(stubAddr[:], []byte("testaddr")) 313 sendToServer := func(method, msg string) { sendToConn(t, conn, method, msg) } 314 315 waitForShutdown := func(tag string, f func()) { 316 needCount := server.clientCount() - 1 317 f() 318 if !giveItASecond(func() bool { 319 return server.clientCount() == needCount 320 }) { 321 t.Fatalf("%s: waitForShutdown failed", tag) 322 } 323 } 324 325 // Register all methods before sending any requests. 326 // 'getclient' grabs the server's link. 327 srvChan := make(chan any) 328 server.Route("getclient", func(c Link, _ *msgjson.Message) *msgjson.Error { 329 client, ok := c.(*wsLink) 330 if !ok { 331 t.Fatalf("failed to assert client type") 332 } 333 srvChan <- client 334 return nil 335 }) 336 getClient := func() { 337 encReq, _ := json.Marshal(makeReq("getclient", `{}`)) 338 conn.msg <- encReq 339 client = readChannel(t, "getClient", srvChan).(*wsLink) 340 } 341 342 // Check request parses the request to a map of strings. 343 server.Route("checkrequest", func(c Link, msg *msgjson.Message) *msgjson.Error { 344 if string(msg.Payload) != `{"key":"value"}` { 345 t.Fatalf("wrong request: %s", string(msg.Payload)) 346 } 347 if client.id != c.ID() { 348 t.Fatalf("client ID mismatch. %d != %d", client.id, c.ID()) 349 } 350 srvChan <- nil 351 return nil 352 }) 353 // 'checkinvalid' should never be run, since the request has invalid 354 // formatting. 355 var passed bool 356 server.Route("checkinvalid", func(_ Link, _ *msgjson.Message) *msgjson.Error { 357 passed = true 358 return nil 359 }) 360 // 'error' returns an Error. 361 server.Route("error", func(_ Link, _ *msgjson.Message) *msgjson.Error { 362 return msgjson.NewError(550, "somemessage") 363 }) 364 // 'ban' quarantines the user using the RPCQuarantineClient error code. 365 server.Route("ban", func(c Link, req *msgjson.Message) *msgjson.Error { 366 rpcErr := msgjson.NewError(msgjson.RPCQuarantineClient, "user quarantined") 367 errMsg, _ := msgjson.NewResponse(req.ID, nil, rpcErr) 368 err := c.Send(errMsg) 369 if err != nil { 370 t.Fatalf("ban route send error: %v", err) 371 } 372 c.Banish() 373 return nil 374 }) 375 var httpSeen uint32 376 server.RegisterHTTP("httproute", func(thing any) (any, error) { 377 atomic.StoreUint32(&httpSeen, 1) 378 srvChan <- nil 379 return struct{}{}, nil 380 }) 381 382 // A helper function to reconnect to the server (new comm) and grab the 383 // server's link (new client). 384 reconnect := func() { 385 conn = newWsStub() 386 387 needCount := server.clientCount() + 1 388 wg.Add(1) 389 go func() { 390 defer wg.Done() 391 server.websocketHandler(testCtx, conn, stubAddr) 392 }() 393 394 if !giveItASecond(func() bool { 395 return server.clientCount() == needCount 396 }) { 397 t.Fatalf("failed to add client") 398 } 399 400 getClient() 401 } 402 403 reconnect() 404 405 // Check that the request is parsed as expected. 406 sendToServer("checkrequest", `{"key":"value"}`) 407 readChannel(t, "checkrequest", srvChan) 408 // Send invalid params, and make sure the server doesn't pass the message. The 409 // server will not disconnect the client. 410 conn.addChan() 411 412 ensureReplaceFails := func(old, new string) { 413 sendReplace(t, conn, makeReq("checkinvalid", old), old, new) 414 <-conn.recv 415 if passed { 416 t.Fatalf("invalid request passed to handler") 417 } 418 } 419 420 ensureReplaceFails(`{"a":"b"}`, "?") 421 if client.Off() { 422 t.Fatalf("client unexpectedly disconnected after invalid message") 423 } 424 425 // Send the invalid message again, but error out on the server's WriteMessage 426 // attempt. The server should disconnect the client in this case. 427 conn.setWriteErr(tErr) 428 waitForShutdown("rpc error", func() { 429 ensureReplaceFails(`{"a":"b"}`, "?") 430 }) 431 432 // Shut the client down. Check the on flag. 433 reconnect() 434 waitForShutdown("flag set", func() { 435 client.Disconnect() 436 }) 437 438 // Reconnect and try shutting down with non-EOF error. 439 reconnect() 440 waitForShutdown("non-EOF", func() { 441 nonEOF <- struct{}{} 442 }) 443 444 // Try a non-existent handler. This should not result in a disconnect. 445 reconnect() 446 conn.addChan() 447 sendToServer("nonexistent", "{}") 448 conn.wait(t, "bad path without error") 449 if client.Off() { 450 t.Fatalf("client unexpectedly disconnected after invalid method") 451 } 452 453 // Again, but with an WriteMessage error when sending error to client. This 454 // should result in a disconnection. 455 conn.setWriteErr(tErr) 456 waitForShutdown("rpc error", func() { 457 sendToServer("nonexistent", "{}") 458 conn.wait(t, "bad path with error") 459 }) 460 461 // An RPC error. No disconnect. 462 reconnect() 463 conn.addChan() 464 sendToServer("error", "{}") 465 conn.wait(t, "rpc error") 466 if client.Off() { 467 t.Fatalf("client unexpectedly disconnected after rpc error") 468 } 469 470 // Return a user quarantine error. 471 waitForShutdown("ban", func() { 472 sendToServer("ban", "{}") 473 conn.wait(t, "ban") 474 }) 475 if !server.isQuarantined(stubAddr) { 476 t.Fatalf("server has not marked client as quarantined") 477 } 478 // A call to Send should return ErrPeerDisconnected 479 if !errors.Is(client.Send(nil), ws.ErrPeerDisconnected) { 480 t.Fatalf("incorrect error for disconnected client") 481 } 482 483 // Test that an http request passes. 484 reconnect() 485 conn.addChan() 486 sendToServer("httproute", "{}") 487 readChannel(t, "httproute", srvChan) 488 if !atomic.CompareAndSwapUint32(&httpSeen, 1, 0) { 489 t.Fatalf("HTTP route not hit") 490 } 491 conn.wait(t, "http route success") 492 493 // Disable HTTP non-critical HTTP routes and try again. 494 server.EnableDataAPI(false) 495 sendToServer("httproute", "{}") 496 resp := decodeResponse(t, <-conn.recv) 497 if resp.Error == nil || resp.Error.Code != msgjson.TooManyRequestsError { 498 t.Fatalf("no or incorrect error for disabled HTTP route: %v", resp.Error) 499 } 500 if atomic.CompareAndSwapUint32(&httpSeen, 1, 0) { 501 t.Fatalf("disabled HTTP route hit") 502 } 503 504 // Make the route a critical route 505 criticalRoutes["httproute"] = true 506 sendToServer("httproute", "{}") 507 readChannel(t, "httproute", srvChan) 508 if !atomic.CompareAndSwapUint32(&httpSeen, 1, 0) { 509 t.Fatalf("critical HTTP route not hit") 510 } 511 conn.wait(t, "critical http route success") 512 513 checkParseError := func() { 514 resp := decodeResponse(t, <-conn.recv) 515 if resp.Error == nil || resp.Error.Code != msgjson.RPCParseError { 516 t.Fatalf("no error after invalid id") 517 } 518 } 519 520 // Test an invalid ID. 521 reconnect() 522 conn.addChan() 523 msg := makeReq("getclient", `{}`) 524 msg.ID = 555 525 sendReplace(t, conn, msg, "555", "{}") 526 checkParseError() 527 528 // Test null ID 529 sendReplace(t, conn, msg, "555", "null") 530 checkParseError() 531 532 } 533 534 func TestClientResponses(t *testing.T) { 535 server := newServer() 536 var client *wsLink 537 var conn *wsConnStub 538 stubAddr := dex.IPKey{} 539 copy(stubAddr[:], []byte("testaddr")) 540 541 // Register all methods before sending any requests. 542 // 'getclient' grabs the server's link. 543 srvChan := make(chan any) 544 server.Route("grabclient", func(c Link, _ *msgjson.Message) *msgjson.Error { 545 client, ok := c.(*wsLink) 546 if !ok { 547 t.Fatalf("failed to assert client type") 548 } 549 srvChan <- client 550 return nil 551 }) 552 553 getClient := func() { 554 encReq, _ := json.Marshal(makeReq("grabclient", `{}`)) 555 conn.msg <- encReq 556 client = readChannel(t, "grabclient", srvChan).(*wsLink) 557 } 558 559 sendToClient := func(route, payload string, f func(Link, *msgjson.Message), expiration time.Duration, expire func()) uint64 { 560 req := makeReq(route, payload) 561 err := client.Request(req, f, expiration, expire) 562 if err != nil { 563 t.Logf("sendToClient error: %v", err) 564 } 565 return req.ID 566 } 567 568 respondToServer := func(id uint64, msg string) { 569 encResp, err := json.Marshal(makeResp(id, msg)) 570 if err != nil { 571 t.Fatalf("error encoding %v (%T) request: %v", id, id, err) 572 } 573 conn.msg <- encResp 574 } 575 576 var wg sync.WaitGroup 577 reconnect := func() { 578 conn = newWsStub() 579 wg.Add(1) 580 go func() { 581 defer wg.Done() 582 server.websocketHandler(testCtx, conn, stubAddr) 583 }() 584 getClient() 585 } 586 reconnect() 587 588 defer func() { 589 server.disconnectClients() 590 wg.Wait() 591 }() 592 593 // Test Broadcast 594 conn.addChan() // for WriteMessage in this test 595 server.Broadcast(makeNtfn("someNote", `"blah"`)) // async conn.recv <- msg send 596 msgBytes := <-conn.recv 597 msg, err := msgjson.DecodeMessage(msgBytes) 598 if err != nil { 599 t.Fatalf("error decoding last message: %v", err) 600 } 601 var note string 602 err = json.Unmarshal(msg.Payload, ¬e) 603 if err != nil { 604 return 605 } 606 if note != "blah" { 607 t.Errorf("wrong note: %s", note) 608 } 609 610 // Send a request from the server to the client, setting a flag when the 611 // client responds. 612 id := sendToClient("looptest", `{}`, func(_ Link, _ *msgjson.Message) { 613 srvChan <- nil 614 }, time.Hour, func() {}) 615 616 // Respond to the server 617 respondToServer(id, `{}`) 618 readChannel(t, "looptest", srvChan) 619 <-conn.recv 620 621 checkParseError := func(tag string) { 622 msg, err := msgjson.DecodeMessage(<-conn.recv) 623 if err != nil { 624 t.Fatalf("error decoding last message (%s): %v", tag, err) 625 } 626 627 resp, err := msg.Response() 628 if err != nil { 629 t.Fatalf("error decoding response (%s): %v", tag, err) 630 } 631 if resp.Error == nil || resp.Error.Code != msgjson.RPCParseError { 632 t.Fatalf("no error after %s", tag) 633 } 634 } 635 636 // Test an invalid id. 637 sendReplace(t, conn, makeResp(1, `{}`), `:1`, `:0`) 638 639 checkParseError("invalid id") 640 641 // Send an invalid payload. 642 old := `{"a":"b"}` 643 sendReplace(t, conn, makeResp(id, old), old, `?`) 644 checkParseError("invalid payload") 645 646 // check the response handler expiration 647 client.respHandlers = make(map[uint64]*responseHandler) 648 expiredID := sendToClient("expiration", `{}`, func(_ Link, _ *msgjson.Message) {}, 649 200*time.Millisecond, func() { t.Log("Expired (good).") }) 650 <-conn.recv 651 // The responseHandler map should contain the ntfn ID since expiry has not 652 // yet arrived. 653 client.reqMtx.Lock() 654 _, found := client.respHandlers[expiredID] 655 if !found { 656 t.Fatalf("response handler not found") 657 } 658 if len(client.respHandlers) != 1 { 659 t.Fatalf("expected 1 response handler, found %d", len(client.respHandlers)) 660 } 661 client.reqMtx.Unlock() 662 663 time.Sleep(250 * time.Millisecond) // >> 200ms - 10ms 664 client.reqMtx.Lock() 665 if len(client.respHandlers) != 0 { 666 t.Fatalf("expired response handler not pruned") 667 } 668 _, found = client.respHandlers[expiredID] 669 if found { 670 t.Fatalf("expired response handler still in map") 671 } 672 client.reqMtx.Unlock() 673 } 674 675 func TestOnline(t *testing.T) { 676 tempDir := t.TempDir() 677 678 keyPath := filepath.Join(tempDir, "rpc.key") 679 certPath := filepath.Join(tempDir, "rpc.cert") 680 pongWait = time.Millisecond * 500 681 pingPeriod = (pongWait * 9) / 10 682 server, err := NewServer(&RPCConfig{ 683 ListenAddrs: []string{"127.0.0.1:0"}, 684 RPCKey: keyPath, 685 RPCCert: certPath, 686 }) 687 if err != nil { 688 t.Fatalf("server constructor error: %v", err) 689 } 690 address := "wss://" + server.listeners[0].Addr().String() + "/ws" 691 692 // Register routes before starting server. 693 // The 'ok' route returns an affirmative response. 694 type okresult struct { 695 OK bool `json:"ok"` 696 } 697 server.Route("ok", func(c Link, msg *msgjson.Message) *msgjson.Error { 698 resp, err := msgjson.NewResponse(msg.ID, &okresult{OK: true}, nil) 699 if err != nil { 700 return msgjson.NewError(500, "%v", err) 701 } 702 err = c.Send(resp) 703 if err != nil { 704 return msgjson.NewError(500, "%v", err) 705 } 706 return nil 707 }) 708 // The 'banuser' route quarantines the user. 709 banChan := make(chan any) 710 server.Route("banuser", func(c Link, req *msgjson.Message) *msgjson.Error { 711 rpcErr := msgjson.NewError(msgjson.RPCQuarantineClient, "test quarantine") 712 msg, _ := msgjson.NewResponse(req.ID, nil, rpcErr) 713 err := c.Send(msg) 714 if err != nil { 715 t.Fatalf("banuser route send error: %v", err) 716 } 717 c.Banish() 718 banChan <- nil 719 return nil 720 }) 721 722 ssw := dex.NewStartStopWaiter(server) 723 ssw.Start(testCtx) 724 defer func() { 725 ssw.Stop() 726 ssw.WaitForShutdown() 727 }() 728 729 // Get the SystemCertPool, continue with an empty pool on error 730 rootCAs, _ := x509.SystemCertPool() 731 if rootCAs == nil { 732 rootCAs = x509.NewCertPool() 733 } 734 735 // Read in the cert file 736 certs, err := os.ReadFile(certPath) 737 if err != nil { 738 t.Fatalf("Failed to append %q to RootCAs: %v", certPath, err) 739 } 740 741 // Append our cert to the system pool 742 if ok := rootCAs.AppendCertsFromPEM(certs); !ok { 743 t.Fatalf("No certs appended, using system certs only") 744 } 745 746 remoteClient, err := newTestBisonWallet(address, rootCAs) 747 if err != nil { 748 t.Fatalf("remoteClient constructor error: %v", err) 749 } 750 751 // A loop to grab responses from the server. 752 recv := make(chan any) 753 go func() { 754 for { 755 _, r, err := remoteClient.ReadMessage() 756 if err == nil { 757 recv <- r 758 } else { 759 recv <- err 760 break 761 } 762 } 763 }() 764 765 sendToDEX := func(route, msg string) error { 766 b, err := json.Marshal(makeReq(route, msg)) 767 if err != nil { 768 t.Fatalf("error encoding %s request: %v", route, err) 769 } 770 err = remoteClient.WriteMessage(websocket.TextMessage, b) 771 return err 772 } 773 774 // Sleep for a couple of pongs to make sure the client doesn't disconnect. 775 time.Sleep(pongWait * 2) 776 777 // Positive path. 778 err = sendToDEX("ok", "{}") 779 if err != nil { 780 t.Fatalf("noresponse send error: %v", err) 781 } 782 b := readChannel(t, "ok", recv).([]byte) 783 784 msg, _ := msgjson.DecodeMessage(b) 785 786 ok := new(okresult) 787 err = msg.UnmarshalResult(ok) 788 if err != nil { 789 t.Fatalf("'ok' response unmarshal error: %v", err) 790 } 791 if !ok.OK { 792 t.Fatalf("ok.OK false") 793 } 794 795 // Ban the client using the special Error code. 796 err = sendToDEX("banuser", "{}") 797 if err != nil { 798 t.Fatalf("banuser send error: %v", err) 799 } 800 // Just for sequencing 801 readChannel(t, "noresponse", banChan) 802 803 msgB := readChannel(t, "banuser msg", recv).([]byte) 804 if !strings.Contains(string(msgB), "test quarantine") { 805 t.Fatalf("wrong ban message received: %s", string(msgB)) 806 } 807 808 err = readChannel(t, "banuser err", recv).(error) 809 if err == nil { 810 t.Fatalf("no read error after ban") 811 } 812 813 // Try connecting, and make sure there is an error. 814 _, err = newTestBisonWallet(address, rootCAs) 815 if err == nil { 816 t.Fatalf("no websocket connection error after ban") 817 } 818 // Manually set the ban time. 819 server.banMtx.Lock() 820 if len(server.quarantine) != 1 { 821 t.Fatalf("unexpected number of quarantined IPs") 822 } 823 for ip := range server.quarantine { 824 server.quarantine[ip] = time.Now() 825 } 826 server.banMtx.Unlock() 827 // Now try again. Should connect. 828 conn, err := newTestBisonWallet(address, rootCAs) 829 if err != nil { 830 t.Fatalf("error connecting on expired ban") 831 } 832 var clientCount uint64 833 if !giveItASecond(func() bool { 834 clientCount = server.clientCount() 835 return clientCount == 1 836 }) { 837 t.Fatalf("server claiming %d clients. Expected 1", clientCount) 838 } 839 conn.Close() 840 } 841 842 func TestParseListeners(t *testing.T) { 843 ipv6wPort := "[fdc5:f621:d3b4:923f::]:80" 844 ipv6wZonePort := "[a:b:c:d::%123]:45" 845 // Invalid because capital letter O. 846 ipv6Invalid := "[1200:0000:AB00:1234:O000:2552:7777:1313]:1234" 847 ipv4wPort := "36.182.54.55:80" 848 849 ips := []string{ 850 ipv6wPort, 851 ipv6wZonePort, 852 ipv4wPort, 853 } 854 855 out4, out6, hasWildcard, err := parseListeners(ips) 856 if err != nil { 857 t.Fatalf("error parsing listeners: %v", err) 858 } 859 if len(out4) != 1 { 860 t.Fatalf("expected 1 ipv4 addresses. found %d", len(out4)) 861 } 862 if len(out6) != 2 { 863 t.Fatalf("expected 2 ipv6 addresses. found %d", len(out6)) 864 } 865 if hasWildcard { 866 t.Fatal("hasWildcard true. should be false.") 867 } 868 869 // Port-only address goes in both. 870 ips = append(ips, ":1234") 871 out4, out6, hasWildcard, err = parseListeners(ips) 872 if err != nil { 873 t.Fatalf("error parsing listeners with wildcard: %v", err) 874 } 875 if len(out4) != 2 { 876 t.Fatalf("expected 2 ipv4 addresses. found %d", len(out4)) 877 } 878 if len(out6) != 3 { 879 t.Fatalf("expected 3 ipv6 addresses. found %d", len(out6)) 880 } 881 if !hasWildcard { 882 t.Fatal("hasWildcard false with port-only address") 883 } 884 885 // No port is invalid 886 ips = append(ips, "localhost") 887 _, _, _, err = parseListeners(ips) 888 if err == nil { 889 t.Fatal("no error when no IP specified") 890 } 891 892 // Pass invalid address 893 _, _, _, err = parseListeners([]string{ipv6Invalid}) 894 if err == nil { 895 t.Fatal("no error with invalid address") 896 } 897 } 898 899 type tHTTPHandler struct { 900 count uint32 901 } 902 903 func (h *tHTTPHandler) ServeHTTP(http.ResponseWriter, *http.Request) { 904 atomic.AddUint32(&h.count, 1) 905 } 906 907 func TestHTTPRateLimiter(t *testing.T) { 908 tHandler := &tHTTPHandler{} 909 s := Server{dataEnabled: 1} 910 911 f := s.LimitRate(tHandler) 912 ip := "ip" 913 req := &http.Request{RemoteAddr: ip} 914 recorder := httptest.NewRecorder() 915 for i := 0; i < ipMaxBurstSize; i++ { 916 f.ServeHTTP(recorder, req) 917 } 918 time.Sleep(100 * time.Millisecond) 919 f.ServeHTTP(recorder, req) 920 successes := atomic.LoadUint32(&tHandler.count) 921 if successes != ipMaxBurstSize { 922 t.Fatalf("expected %d requests. got %d", ipMaxBurstSize, successes) 923 } 924 statusCode := recorder.Result().StatusCode 925 if statusCode != http.StatusTooManyRequests { 926 t.Fatalf("wrong status code. wanted %d, got %d", http.StatusTooManyRequests, statusCode) 927 } 928 } 929 930 func TestWSRateLimiter(t *testing.T) { 931 server := newServer() 932 var wg sync.WaitGroup 933 defer func() { 934 server.disconnectClients() 935 wg.Wait() 936 }() 937 938 handled := make(chan struct{}, 1) 939 940 server.Route(msgjson.FeeRateRoute, func(Link, *msgjson.Message) *msgjson.Error { 941 handled <- struct{}{} 942 return nil 943 }) 944 945 server.Route(msgjson.OrderBookRoute, func(Link, *msgjson.Message) *msgjson.Error { 946 handled <- struct{}{} 947 return nil 948 }) 949 950 conn := newWsStub() 951 conn.addChan() // for <-conn.recv 952 conn.addNextChan() // for <-conn.nextRead, each time ReadMessage is called 953 954 wg.Add(1) 955 go func(conn *wsConnStub) { 956 defer wg.Done() 957 stubAddr := dex.NewIPKey("aabb:cc:ddee:ff::abc") // "abc" chopped by NewIPKey, "ff" chopped by PrefixV6 958 if stubAddr.IsUnspecified() { 959 t.Errorf("bad addr") 960 return 961 } 962 server.websocketHandler(testCtx, conn, stubAddr) // newWSLink -> Connect -> readloop will call handleMessage 963 close(conn.nextRead) // must be after read loop has quit (sends on nextRead) 964 }(conn) 965 966 <-conn.nextRead 967 go func() { // connected, so just keep receiving on the channel 968 for range conn.nextRead { 969 } 970 }() 971 972 waitResult := func() int { 973 t.Helper() 974 select { 975 case <-handled: 976 return 0 977 case resp := <-conn.recv: // test handlers only return resp with error (rate limit) 978 // t.Log("handler error message:", string(resp)) 979 msg, err := msgjson.DecodeMessage(resp) 980 if err != nil { 981 t.Fatalf("failed to decode response message: %v", err) 982 } 983 payload, err := msg.Response() 984 if err != nil { 985 t.Fatalf("failed to decode response: %v", err) 986 } 987 if payload.Error == nil { 988 t.Fatalf("Expected rate limiting error, got none.") 989 } 990 if payload.Error.Code != msgjson.TooManyRequestsError { 991 t.Fatalf("Wanted code %d, got %d.", msgjson.TooManyRequestsError, payload.Error.Code) 992 } 993 if !strings.HasPrefix(payload.Error.Message, "too many requests") { 994 t.Fatalf("Wanted message with prefix %q, got %q.", "too many requests", payload.Error.Message) 995 } 996 return 1 997 case <-time.After(5 * time.Second): 998 t.Fatal("timeout") 999 } 1000 return 2 1001 } 1002 1003 // Other routes still work. 1004 sendToConn(t, conn, msgjson.FeeRateRoute, `{}`) 1005 if waitResult() != 0 { 1006 t.Fatalf("fee_rate request failed") 1007 } 1008 1009 // orderbook, which has 1 r/s rate limit, 100 burst 1010 sendToConn(t, conn, msgjson.OrderBookRoute, `{}`) 1011 if waitResult() != 0 { 1012 t.Fatalf("orderbook request failed") 1013 } 1014 sendToConn(t, conn, msgjson.OrderBookRoute, `{}`) 1015 if waitResult() != 0 { // tests burst > 1 1016 t.Fatalf("orderbook request failed") 1017 } 1018 1019 // New connection from different address. 1020 conn = newWsStub() 1021 conn.addChan() // for <-conn.recv 1022 conn.addNextChan() // for <-conn.nextRead, each time ReadMessage is called 1023 1024 wg.Add(1) 1025 go func(conn *wsConnStub) { 1026 defer wg.Done() 1027 stubAddr := dex.NewIPKey("aabb:cc:ddee:11::") // same prefix, different subnet 1028 if stubAddr.IsUnspecified() { 1029 t.Errorf("bad addr") 1030 return 1031 } 1032 server.websocketHandler(testCtx, conn, stubAddr) // newWSLink -> Connect -> readloop will call handleMessage 1033 close(conn.nextRead) // must be after read loop has quit (sends on nextRead) 1034 }(conn) 1035 1036 <-conn.nextRead 1037 go func() { // connected, so just keep receiving on the channel 1038 for range conn.nextRead { 1039 } 1040 }() 1041 }