decred.org/dcrdex@v1.0.5/server/comms/comms_test.go (about)

     1  package comms
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/tls"
     7  	"crypto/x509"
     8  	"encoding/json"
     9  	"errors"
    10  	"fmt"
    11  	"net/http"
    12  	"net/http/httptest"
    13  	"net/url"
    14  	"os"
    15  	"path/filepath"
    16  	"strings"
    17  	"sync"
    18  	"sync/atomic"
    19  	"testing"
    20  	"time"
    21  
    22  	"decred.org/dcrdex/dex"
    23  	"decred.org/dcrdex/dex/msgjson"
    24  	"decred.org/dcrdex/dex/ws"
    25  	"github.com/gorilla/websocket"
    26  )
    27  
    28  var (
    29  	tErr    = fmt.Errorf("test error")
    30  	testCtx context.Context
    31  	tLogger = dex.StdOutLogger("TCOMMS", dex.LevelTrace)
    32  )
    33  
    34  func newServer() *Server {
    35  	s := &Server{
    36  		clients:     make(map[uint64]*wsLink),
    37  		wsLimiters:  make(map[dex.IPKey]*ipWsLimiter),
    38  		v6Prefixes:  make(map[dex.IPKey]int),
    39  		quarantine:  make(map[dex.IPKey]time.Time),
    40  		dataEnabled: 1,
    41  		rpcRoutes:   make(map[string]MsgHandler),
    42  		httpRoutes:  make(map[string]HTTPHandler),
    43  	}
    44  	for _, route := range []string{msgjson.ConfigRoute, msgjson.SpotsRoute, msgjson.CandlesRoute, msgjson.OrderBookRoute} {
    45  		s.RegisterHTTP(route, func(any) (any, error) { return nil, nil })
    46  	}
    47  	return s
    48  }
    49  
    50  func giveItASecond(f func() bool) bool {
    51  	ticker := time.NewTicker(time.Millisecond)
    52  	timeout := time.NewTimer(time.Second)
    53  	for {
    54  		if f() {
    55  			return true
    56  		}
    57  		select {
    58  		case <-timeout.C:
    59  			return false
    60  		default:
    61  		}
    62  		<-ticker.C
    63  	}
    64  }
    65  
    66  func readChannel(t *testing.T, tag string, c chan any) any {
    67  	t.Helper()
    68  	select {
    69  	case i := <-c:
    70  		return i
    71  	case <-time.NewTimer(time.Second).C:
    72  		t.Fatalf("%s: didn't read channel", tag)
    73  	}
    74  	return nil
    75  }
    76  
    77  func decodeResponse(t *testing.T, b []byte) *msgjson.ResponsePayload {
    78  	t.Helper()
    79  	msg, err := msgjson.DecodeMessage(b)
    80  	if err != nil {
    81  		t.Fatalf("error decoding last message (%s): %v", string(b), err)
    82  	}
    83  	resp, err := msg.Response()
    84  	if err != nil {
    85  		t.Fatalf("error decoding response payload: %v", err)
    86  	}
    87  	return resp
    88  }
    89  
    90  type wsConnStub struct {
    91  	msg      chan []byte
    92  	quit     chan struct{}
    93  	close    int
    94  	recv     chan []byte
    95  	nextRead chan struct{} // helps detect when (*WSLink).inHandler is running
    96  	writeMtx sync.Mutex
    97  	writeErr error
    98  }
    99  
   100  func (conn *wsConnStub) addChan() {
   101  	conn.recv = make(chan []byte)
   102  }
   103  
   104  func (conn *wsConnStub) addNextChan() {
   105  	conn.nextRead = make(chan struct{}, 1) // send when ReadMessage() is called
   106  }
   107  
   108  func (conn *wsConnStub) wait(t *testing.T, tag string) {
   109  	t.Helper()
   110  	select {
   111  	case <-conn.recv:
   112  	case <-time.NewTimer(time.Second).C:
   113  		t.Fatalf("%s - wait timeout", tag)
   114  	}
   115  }
   116  
   117  func newWsStub() *wsConnStub {
   118  	return &wsConnStub{
   119  		msg: make(chan []byte),
   120  		// recv is nil unless a test wants to receive
   121  		quit: make(chan struct{}),
   122  	}
   123  }
   124  
   125  func (conn *wsConnStub) setWriteErr(err error) {
   126  	conn.writeMtx.Lock()
   127  	conn.writeErr = err
   128  	conn.writeMtx.Unlock()
   129  }
   130  
   131  // nonEOF can specify a particular error should be returned through ReadMessage.
   132  var nonEOF = make(chan struct{})
   133  var pongTrigger = []byte("pong")
   134  
   135  func (conn *wsConnStub) ReadMessage() (int, []byte, error) {
   136  	if conn.nextRead != nil {
   137  		conn.nextRead <- struct{}{}
   138  	}
   139  
   140  	var b []byte
   141  	select {
   142  	case b = <-conn.msg:
   143  		if bytes.Equal(b, pongTrigger) {
   144  			return websocket.PongMessage, []byte{}, nil
   145  		}
   146  	case <-conn.quit:
   147  		return 0, nil, &websocket.CloseError{Code: websocket.CloseGoingAway, Text: "bye"}
   148  	case <-testCtx.Done():
   149  		return 0, nil, &websocket.CloseError{Code: websocket.CloseGoingAway, Text: "bye"}
   150  	case <-nonEOF:
   151  		close(conn.quit)
   152  		return 0, nil, fmt.Errorf("test nonEOF error")
   153  	}
   154  	return 0, b, nil
   155  }
   156  
   157  func (conn *wsConnStub) WriteMessage(msgType int, msg []byte) error {
   158  	conn.writeMtx.Lock()
   159  	defer conn.writeMtx.Unlock()
   160  	if msgType == websocket.PingMessage {
   161  		select {
   162  		case conn.msg <- pongTrigger:
   163  		default:
   164  		}
   165  		return nil
   166  	}
   167  	// Send the message if their is a receiver for the current test.
   168  	if conn.recv != nil {
   169  		conn.recv <- msg
   170  	}
   171  	if conn.writeErr == nil {
   172  		return nil
   173  	}
   174  	err := conn.writeErr
   175  	conn.writeErr = nil
   176  	return err
   177  }
   178  
   179  func (conn *wsConnStub) SetReadLimit(int64) {}
   180  
   181  func (conn *wsConnStub) SetWriteDeadline(t time.Time) error {
   182  	return nil // TODO implement and test write timeouts
   183  }
   184  
   185  func (conn *wsConnStub) SetReadDeadline(t time.Time) error {
   186  	return nil
   187  }
   188  
   189  func (conn *wsConnStub) WriteControl(messageType int, data []byte, deadline time.Time) error {
   190  	return nil
   191  }
   192  
   193  func (conn *wsConnStub) Close() error {
   194  	select {
   195  	case <-conn.quit:
   196  	default:
   197  		close(conn.quit)
   198  	}
   199  	conn.close++
   200  	return nil
   201  }
   202  
   203  func dummyRPCHandler(_ Link, _ *msgjson.Message) *msgjson.Error {
   204  	return nil
   205  }
   206  
   207  var reqID uint64
   208  
   209  func makeReq(route, msg string) *msgjson.Message {
   210  	reqID++
   211  	req, err := msgjson.NewRequest(reqID, route, json.RawMessage(msg))
   212  	if err != nil {
   213  		panic("bad request message")
   214  	}
   215  	return req
   216  }
   217  
   218  func makeResp(id uint64, msg string) *msgjson.Message {
   219  	resp, _ := msgjson.NewResponse(id, json.RawMessage(msg), nil)
   220  	return resp
   221  }
   222  
   223  func makeNtfn(route, msg string) *msgjson.Message {
   224  	ntfn, _ := msgjson.NewNotification(route, json.RawMessage(msg))
   225  	return ntfn
   226  }
   227  
   228  func sendToConn(t *testing.T, conn *wsConnStub, method, msg string) {
   229  	t.Helper()
   230  	encMsg, err := json.Marshal(makeReq(method, msg))
   231  	if err != nil {
   232  		t.Fatalf("error encoding %s request: %v", method, err)
   233  	}
   234  	conn.msg <- encMsg
   235  }
   236  
   237  func sendReplace(t *testing.T, conn *wsConnStub, thing any, old, new string) {
   238  	enc, err := json.Marshal(thing)
   239  	if err != nil {
   240  		t.Fatalf("error encoding thing for sendReplace: %v", err)
   241  	}
   242  	s := string(enc)
   243  	s = strings.ReplaceAll(s, old, new)
   244  	conn.msg <- []byte(s)
   245  }
   246  
   247  func newTestBisonWallet(addr string, rootCAs *x509.CertPool) (*websocket.Conn, error) {
   248  	uri, err := url.Parse(addr)
   249  	if err != nil {
   250  		return nil, fmt.Errorf("error parsing url: %w", err)
   251  	}
   252  
   253  	dialer := &websocket.Dialer{
   254  		Proxy:            http.ProxyFromEnvironment, // Same as DefaultDialer.
   255  		HandshakeTimeout: 10 * time.Second,          // DefaultDialer is 45 seconds.
   256  		TLSClientConfig: &tls.Config{
   257  			RootCAs:            rootCAs,
   258  			InsecureSkipVerify: true,
   259  			ServerName:         uri.Hostname(),
   260  		},
   261  	}
   262  
   263  	conn, _, err := dialer.Dial(addr, nil)
   264  	if err != nil {
   265  		return nil, err
   266  	}
   267  	return conn, nil
   268  }
   269  
   270  func TestMain(m *testing.M) {
   271  	var shutdown func()
   272  	testCtx, shutdown = context.WithCancel(context.Background())
   273  	defer shutdown()
   274  	UseLogger(tLogger)
   275  	os.Exit(m.Run())
   276  }
   277  
   278  // method strings cannot be empty.
   279  func TestRoute_PanicsEmptyString(t *testing.T) {
   280  	defer func() {
   281  		if r := recover(); r == nil {
   282  			t.Fatalf("no panic on registering empty string method")
   283  		}
   284  	}()
   285  	s := newServer()
   286  	s.Route("", dummyRPCHandler)
   287  }
   288  
   289  // methods cannot be registered more than once.
   290  func TestRoute_PanicsDoubleRegistry(t *testing.T) {
   291  	defer func() {
   292  		if r := recover(); r == nil {
   293  			t.Fatalf("no panic on registering empty string method")
   294  		}
   295  	}()
   296  	s := newServer()
   297  	s.Route("somemethod", dummyRPCHandler)
   298  	s.Route("somemethod", dummyRPCHandler)
   299  }
   300  
   301  // Test the server with a stub for the client connections.
   302  func TestClientRequests(t *testing.T) {
   303  	server := newServer()
   304  	var wg sync.WaitGroup
   305  	defer func() {
   306  		server.disconnectClients()
   307  		wg.Wait()
   308  	}()
   309  	var client *wsLink
   310  	var conn *wsConnStub
   311  	stubAddr := dex.IPKey{}
   312  	copy(stubAddr[:], []byte("testaddr"))
   313  	sendToServer := func(method, msg string) { sendToConn(t, conn, method, msg) }
   314  
   315  	waitForShutdown := func(tag string, f func()) {
   316  		needCount := server.clientCount() - 1
   317  		f()
   318  		if !giveItASecond(func() bool {
   319  			return server.clientCount() == needCount
   320  		}) {
   321  			t.Fatalf("%s: waitForShutdown failed", tag)
   322  		}
   323  	}
   324  
   325  	// Register all methods before sending any requests.
   326  	// 'getclient' grabs the server's link.
   327  	srvChan := make(chan any)
   328  	server.Route("getclient", func(c Link, _ *msgjson.Message) *msgjson.Error {
   329  		client, ok := c.(*wsLink)
   330  		if !ok {
   331  			t.Fatalf("failed to assert client type")
   332  		}
   333  		srvChan <- client
   334  		return nil
   335  	})
   336  	getClient := func() {
   337  		encReq, _ := json.Marshal(makeReq("getclient", `{}`))
   338  		conn.msg <- encReq
   339  		client = readChannel(t, "getClient", srvChan).(*wsLink)
   340  	}
   341  
   342  	// Check request parses the request to a map of strings.
   343  	server.Route("checkrequest", func(c Link, msg *msgjson.Message) *msgjson.Error {
   344  		if string(msg.Payload) != `{"key":"value"}` {
   345  			t.Fatalf("wrong request: %s", string(msg.Payload))
   346  		}
   347  		if client.id != c.ID() {
   348  			t.Fatalf("client ID mismatch. %d != %d", client.id, c.ID())
   349  		}
   350  		srvChan <- nil
   351  		return nil
   352  	})
   353  	// 'checkinvalid' should never be run, since the request has invalid
   354  	// formatting.
   355  	var passed bool
   356  	server.Route("checkinvalid", func(_ Link, _ *msgjson.Message) *msgjson.Error {
   357  		passed = true
   358  		return nil
   359  	})
   360  	// 'error' returns an Error.
   361  	server.Route("error", func(_ Link, _ *msgjson.Message) *msgjson.Error {
   362  		return msgjson.NewError(550, "somemessage")
   363  	})
   364  	// 'ban' quarantines the user using the RPCQuarantineClient error code.
   365  	server.Route("ban", func(c Link, req *msgjson.Message) *msgjson.Error {
   366  		rpcErr := msgjson.NewError(msgjson.RPCQuarantineClient, "user quarantined")
   367  		errMsg, _ := msgjson.NewResponse(req.ID, nil, rpcErr)
   368  		err := c.Send(errMsg)
   369  		if err != nil {
   370  			t.Fatalf("ban route send error: %v", err)
   371  		}
   372  		c.Banish()
   373  		return nil
   374  	})
   375  	var httpSeen uint32
   376  	server.RegisterHTTP("httproute", func(thing any) (any, error) {
   377  		atomic.StoreUint32(&httpSeen, 1)
   378  		srvChan <- nil
   379  		return struct{}{}, nil
   380  	})
   381  
   382  	// A helper function to reconnect to the server (new comm) and grab the
   383  	// server's link (new client).
   384  	reconnect := func() {
   385  		conn = newWsStub()
   386  
   387  		needCount := server.clientCount() + 1
   388  		wg.Add(1)
   389  		go func() {
   390  			defer wg.Done()
   391  			server.websocketHandler(testCtx, conn, stubAddr)
   392  		}()
   393  
   394  		if !giveItASecond(func() bool {
   395  			return server.clientCount() == needCount
   396  		}) {
   397  			t.Fatalf("failed to add client")
   398  		}
   399  
   400  		getClient()
   401  	}
   402  
   403  	reconnect()
   404  
   405  	// Check that the request is parsed as expected.
   406  	sendToServer("checkrequest", `{"key":"value"}`)
   407  	readChannel(t, "checkrequest", srvChan)
   408  	// Send invalid params, and make sure the server doesn't pass the message. The
   409  	// server will not disconnect the client.
   410  	conn.addChan()
   411  
   412  	ensureReplaceFails := func(old, new string) {
   413  		sendReplace(t, conn, makeReq("checkinvalid", old), old, new)
   414  		<-conn.recv
   415  		if passed {
   416  			t.Fatalf("invalid request passed to handler")
   417  		}
   418  	}
   419  
   420  	ensureReplaceFails(`{"a":"b"}`, "?")
   421  	if client.Off() {
   422  		t.Fatalf("client unexpectedly disconnected after invalid message")
   423  	}
   424  
   425  	// Send the invalid message again, but error out on the server's WriteMessage
   426  	// attempt. The server should disconnect the client in this case.
   427  	conn.setWriteErr(tErr)
   428  	waitForShutdown("rpc error", func() {
   429  		ensureReplaceFails(`{"a":"b"}`, "?")
   430  	})
   431  
   432  	// Shut the client down. Check the on flag.
   433  	reconnect()
   434  	waitForShutdown("flag set", func() {
   435  		client.Disconnect()
   436  	})
   437  
   438  	// Reconnect and try shutting down with non-EOF error.
   439  	reconnect()
   440  	waitForShutdown("non-EOF", func() {
   441  		nonEOF <- struct{}{}
   442  	})
   443  
   444  	// Try a non-existent handler. This should not result in a disconnect.
   445  	reconnect()
   446  	conn.addChan()
   447  	sendToServer("nonexistent", "{}")
   448  	conn.wait(t, "bad path without error")
   449  	if client.Off() {
   450  		t.Fatalf("client unexpectedly disconnected after invalid method")
   451  	}
   452  
   453  	// Again, but with an WriteMessage error when sending error to client. This
   454  	// should result in a disconnection.
   455  	conn.setWriteErr(tErr)
   456  	waitForShutdown("rpc error", func() {
   457  		sendToServer("nonexistent", "{}")
   458  		conn.wait(t, "bad path with error")
   459  	})
   460  
   461  	// An RPC error. No disconnect.
   462  	reconnect()
   463  	conn.addChan()
   464  	sendToServer("error", "{}")
   465  	conn.wait(t, "rpc error")
   466  	if client.Off() {
   467  		t.Fatalf("client unexpectedly disconnected after rpc error")
   468  	}
   469  
   470  	// Return a user quarantine error.
   471  	waitForShutdown("ban", func() {
   472  		sendToServer("ban", "{}")
   473  		conn.wait(t, "ban")
   474  	})
   475  	if !server.isQuarantined(stubAddr) {
   476  		t.Fatalf("server has not marked client as quarantined")
   477  	}
   478  	// A call to Send should return ErrPeerDisconnected
   479  	if !errors.Is(client.Send(nil), ws.ErrPeerDisconnected) {
   480  		t.Fatalf("incorrect error for disconnected client")
   481  	}
   482  
   483  	// Test that an http request passes.
   484  	reconnect()
   485  	conn.addChan()
   486  	sendToServer("httproute", "{}")
   487  	readChannel(t, "httproute", srvChan)
   488  	if !atomic.CompareAndSwapUint32(&httpSeen, 1, 0) {
   489  		t.Fatalf("HTTP route not hit")
   490  	}
   491  	conn.wait(t, "http route success")
   492  
   493  	// Disable HTTP non-critical HTTP routes and try again.
   494  	server.EnableDataAPI(false)
   495  	sendToServer("httproute", "{}")
   496  	resp := decodeResponse(t, <-conn.recv)
   497  	if resp.Error == nil || resp.Error.Code != msgjson.TooManyRequestsError {
   498  		t.Fatalf("no or incorrect error for disabled HTTP route: %v", resp.Error)
   499  	}
   500  	if atomic.CompareAndSwapUint32(&httpSeen, 1, 0) {
   501  		t.Fatalf("disabled HTTP route hit")
   502  	}
   503  
   504  	// Make the route a critical route
   505  	criticalRoutes["httproute"] = true
   506  	sendToServer("httproute", "{}")
   507  	readChannel(t, "httproute", srvChan)
   508  	if !atomic.CompareAndSwapUint32(&httpSeen, 1, 0) {
   509  		t.Fatalf("critical HTTP route not hit")
   510  	}
   511  	conn.wait(t, "critical http route success")
   512  
   513  	checkParseError := func() {
   514  		resp := decodeResponse(t, <-conn.recv)
   515  		if resp.Error == nil || resp.Error.Code != msgjson.RPCParseError {
   516  			t.Fatalf("no error after invalid id")
   517  		}
   518  	}
   519  
   520  	// Test an invalid ID.
   521  	reconnect()
   522  	conn.addChan()
   523  	msg := makeReq("getclient", `{}`)
   524  	msg.ID = 555
   525  	sendReplace(t, conn, msg, "555", "{}")
   526  	checkParseError()
   527  
   528  	// Test null ID
   529  	sendReplace(t, conn, msg, "555", "null")
   530  	checkParseError()
   531  
   532  }
   533  
   534  func TestClientResponses(t *testing.T) {
   535  	server := newServer()
   536  	var client *wsLink
   537  	var conn *wsConnStub
   538  	stubAddr := dex.IPKey{}
   539  	copy(stubAddr[:], []byte("testaddr"))
   540  
   541  	// Register all methods before sending any requests.
   542  	// 'getclient' grabs the server's link.
   543  	srvChan := make(chan any)
   544  	server.Route("grabclient", func(c Link, _ *msgjson.Message) *msgjson.Error {
   545  		client, ok := c.(*wsLink)
   546  		if !ok {
   547  			t.Fatalf("failed to assert client type")
   548  		}
   549  		srvChan <- client
   550  		return nil
   551  	})
   552  
   553  	getClient := func() {
   554  		encReq, _ := json.Marshal(makeReq("grabclient", `{}`))
   555  		conn.msg <- encReq
   556  		client = readChannel(t, "grabclient", srvChan).(*wsLink)
   557  	}
   558  
   559  	sendToClient := func(route, payload string, f func(Link, *msgjson.Message), expiration time.Duration, expire func()) uint64 {
   560  		req := makeReq(route, payload)
   561  		err := client.Request(req, f, expiration, expire)
   562  		if err != nil {
   563  			t.Logf("sendToClient error: %v", err)
   564  		}
   565  		return req.ID
   566  	}
   567  
   568  	respondToServer := func(id uint64, msg string) {
   569  		encResp, err := json.Marshal(makeResp(id, msg))
   570  		if err != nil {
   571  			t.Fatalf("error encoding %v (%T) request: %v", id, id, err)
   572  		}
   573  		conn.msg <- encResp
   574  	}
   575  
   576  	var wg sync.WaitGroup
   577  	reconnect := func() {
   578  		conn = newWsStub()
   579  		wg.Add(1)
   580  		go func() {
   581  			defer wg.Done()
   582  			server.websocketHandler(testCtx, conn, stubAddr)
   583  		}()
   584  		getClient()
   585  	}
   586  	reconnect()
   587  
   588  	defer func() {
   589  		server.disconnectClients()
   590  		wg.Wait()
   591  	}()
   592  
   593  	// Test Broadcast
   594  	conn.addChan()                                   // for WriteMessage in this test
   595  	server.Broadcast(makeNtfn("someNote", `"blah"`)) // async conn.recv <- msg send
   596  	msgBytes := <-conn.recv
   597  	msg, err := msgjson.DecodeMessage(msgBytes)
   598  	if err != nil {
   599  		t.Fatalf("error decoding last message: %v", err)
   600  	}
   601  	var note string
   602  	err = json.Unmarshal(msg.Payload, &note)
   603  	if err != nil {
   604  		return
   605  	}
   606  	if note != "blah" {
   607  		t.Errorf("wrong note: %s", note)
   608  	}
   609  
   610  	// Send a request from the server to the client, setting a flag when the
   611  	// client responds.
   612  	id := sendToClient("looptest", `{}`, func(_ Link, _ *msgjson.Message) {
   613  		srvChan <- nil
   614  	}, time.Hour, func() {})
   615  
   616  	// Respond to the server
   617  	respondToServer(id, `{}`)
   618  	readChannel(t, "looptest", srvChan)
   619  	<-conn.recv
   620  
   621  	checkParseError := func(tag string) {
   622  		msg, err := msgjson.DecodeMessage(<-conn.recv)
   623  		if err != nil {
   624  			t.Fatalf("error decoding last message (%s): %v", tag, err)
   625  		}
   626  
   627  		resp, err := msg.Response()
   628  		if err != nil {
   629  			t.Fatalf("error decoding response (%s): %v", tag, err)
   630  		}
   631  		if resp.Error == nil || resp.Error.Code != msgjson.RPCParseError {
   632  			t.Fatalf("no error after %s", tag)
   633  		}
   634  	}
   635  
   636  	// Test an invalid id.
   637  	sendReplace(t, conn, makeResp(1, `{}`), `:1`, `:0`)
   638  
   639  	checkParseError("invalid id")
   640  
   641  	// Send an invalid payload.
   642  	old := `{"a":"b"}`
   643  	sendReplace(t, conn, makeResp(id, old), old, `?`)
   644  	checkParseError("invalid payload")
   645  
   646  	// check the response handler expiration
   647  	client.respHandlers = make(map[uint64]*responseHandler)
   648  	expiredID := sendToClient("expiration", `{}`, func(_ Link, _ *msgjson.Message) {},
   649  		200*time.Millisecond, func() { t.Log("Expired (good).") })
   650  	<-conn.recv
   651  	// The responseHandler map should contain the ntfn ID since expiry has not
   652  	// yet arrived.
   653  	client.reqMtx.Lock()
   654  	_, found := client.respHandlers[expiredID]
   655  	if !found {
   656  		t.Fatalf("response handler not found")
   657  	}
   658  	if len(client.respHandlers) != 1 {
   659  		t.Fatalf("expected 1 response handler, found %d", len(client.respHandlers))
   660  	}
   661  	client.reqMtx.Unlock()
   662  
   663  	time.Sleep(250 * time.Millisecond) // >> 200ms - 10ms
   664  	client.reqMtx.Lock()
   665  	if len(client.respHandlers) != 0 {
   666  		t.Fatalf("expired response handler not pruned")
   667  	}
   668  	_, found = client.respHandlers[expiredID]
   669  	if found {
   670  		t.Fatalf("expired response handler still in map")
   671  	}
   672  	client.reqMtx.Unlock()
   673  }
   674  
   675  func TestOnline(t *testing.T) {
   676  	tempDir := t.TempDir()
   677  
   678  	keyPath := filepath.Join(tempDir, "rpc.key")
   679  	certPath := filepath.Join(tempDir, "rpc.cert")
   680  	pongWait = time.Millisecond * 500
   681  	pingPeriod = (pongWait * 9) / 10
   682  	server, err := NewServer(&RPCConfig{
   683  		ListenAddrs: []string{"127.0.0.1:0"},
   684  		RPCKey:      keyPath,
   685  		RPCCert:     certPath,
   686  	})
   687  	if err != nil {
   688  		t.Fatalf("server constructor error: %v", err)
   689  	}
   690  	address := "wss://" + server.listeners[0].Addr().String() + "/ws"
   691  
   692  	// Register routes before starting server.
   693  	// The 'ok' route returns an affirmative response.
   694  	type okresult struct {
   695  		OK bool `json:"ok"`
   696  	}
   697  	server.Route("ok", func(c Link, msg *msgjson.Message) *msgjson.Error {
   698  		resp, err := msgjson.NewResponse(msg.ID, &okresult{OK: true}, nil)
   699  		if err != nil {
   700  			return msgjson.NewError(500, "%v", err)
   701  		}
   702  		err = c.Send(resp)
   703  		if err != nil {
   704  			return msgjson.NewError(500, "%v", err)
   705  		}
   706  		return nil
   707  	})
   708  	// The 'banuser' route quarantines the user.
   709  	banChan := make(chan any)
   710  	server.Route("banuser", func(c Link, req *msgjson.Message) *msgjson.Error {
   711  		rpcErr := msgjson.NewError(msgjson.RPCQuarantineClient, "test quarantine")
   712  		msg, _ := msgjson.NewResponse(req.ID, nil, rpcErr)
   713  		err := c.Send(msg)
   714  		if err != nil {
   715  			t.Fatalf("banuser route send error: %v", err)
   716  		}
   717  		c.Banish()
   718  		banChan <- nil
   719  		return nil
   720  	})
   721  
   722  	ssw := dex.NewStartStopWaiter(server)
   723  	ssw.Start(testCtx)
   724  	defer func() {
   725  		ssw.Stop()
   726  		ssw.WaitForShutdown()
   727  	}()
   728  
   729  	// Get the SystemCertPool, continue with an empty pool on error
   730  	rootCAs, _ := x509.SystemCertPool()
   731  	if rootCAs == nil {
   732  		rootCAs = x509.NewCertPool()
   733  	}
   734  
   735  	// Read in the cert file
   736  	certs, err := os.ReadFile(certPath)
   737  	if err != nil {
   738  		t.Fatalf("Failed to append %q to RootCAs: %v", certPath, err)
   739  	}
   740  
   741  	// Append our cert to the system pool
   742  	if ok := rootCAs.AppendCertsFromPEM(certs); !ok {
   743  		t.Fatalf("No certs appended, using system certs only")
   744  	}
   745  
   746  	remoteClient, err := newTestBisonWallet(address, rootCAs)
   747  	if err != nil {
   748  		t.Fatalf("remoteClient constructor error: %v", err)
   749  	}
   750  
   751  	// A loop to grab responses from the server.
   752  	recv := make(chan any)
   753  	go func() {
   754  		for {
   755  			_, r, err := remoteClient.ReadMessage()
   756  			if err == nil {
   757  				recv <- r
   758  			} else {
   759  				recv <- err
   760  				break
   761  			}
   762  		}
   763  	}()
   764  
   765  	sendToDEX := func(route, msg string) error {
   766  		b, err := json.Marshal(makeReq(route, msg))
   767  		if err != nil {
   768  			t.Fatalf("error encoding %s request: %v", route, err)
   769  		}
   770  		err = remoteClient.WriteMessage(websocket.TextMessage, b)
   771  		return err
   772  	}
   773  
   774  	// Sleep for a couple of pongs to make sure the client doesn't disconnect.
   775  	time.Sleep(pongWait * 2)
   776  
   777  	// Positive path.
   778  	err = sendToDEX("ok", "{}")
   779  	if err != nil {
   780  		t.Fatalf("noresponse send error: %v", err)
   781  	}
   782  	b := readChannel(t, "ok", recv).([]byte)
   783  
   784  	msg, _ := msgjson.DecodeMessage(b)
   785  
   786  	ok := new(okresult)
   787  	err = msg.UnmarshalResult(ok)
   788  	if err != nil {
   789  		t.Fatalf("'ok' response unmarshal error: %v", err)
   790  	}
   791  	if !ok.OK {
   792  		t.Fatalf("ok.OK false")
   793  	}
   794  
   795  	// Ban the client using the special Error code.
   796  	err = sendToDEX("banuser", "{}")
   797  	if err != nil {
   798  		t.Fatalf("banuser send error: %v", err)
   799  	}
   800  	// Just for sequencing
   801  	readChannel(t, "noresponse", banChan)
   802  
   803  	msgB := readChannel(t, "banuser msg", recv).([]byte)
   804  	if !strings.Contains(string(msgB), "test quarantine") {
   805  		t.Fatalf("wrong ban message received: %s", string(msgB))
   806  	}
   807  
   808  	err = readChannel(t, "banuser err", recv).(error)
   809  	if err == nil {
   810  		t.Fatalf("no read error after ban")
   811  	}
   812  
   813  	// Try connecting, and make sure there is an error.
   814  	_, err = newTestBisonWallet(address, rootCAs)
   815  	if err == nil {
   816  		t.Fatalf("no websocket connection error after ban")
   817  	}
   818  	// Manually set the ban time.
   819  	server.banMtx.Lock()
   820  	if len(server.quarantine) != 1 {
   821  		t.Fatalf("unexpected number of quarantined IPs")
   822  	}
   823  	for ip := range server.quarantine {
   824  		server.quarantine[ip] = time.Now()
   825  	}
   826  	server.banMtx.Unlock()
   827  	// Now try again. Should connect.
   828  	conn, err := newTestBisonWallet(address, rootCAs)
   829  	if err != nil {
   830  		t.Fatalf("error connecting on expired ban")
   831  	}
   832  	var clientCount uint64
   833  	if !giveItASecond(func() bool {
   834  		clientCount = server.clientCount()
   835  		return clientCount == 1
   836  	}) {
   837  		t.Fatalf("server claiming %d clients. Expected 1", clientCount)
   838  	}
   839  	conn.Close()
   840  }
   841  
   842  func TestParseListeners(t *testing.T) {
   843  	ipv6wPort := "[fdc5:f621:d3b4:923f::]:80"
   844  	ipv6wZonePort := "[a:b:c:d::%123]:45"
   845  	// Invalid because capital letter O.
   846  	ipv6Invalid := "[1200:0000:AB00:1234:O000:2552:7777:1313]:1234"
   847  	ipv4wPort := "36.182.54.55:80"
   848  
   849  	ips := []string{
   850  		ipv6wPort,
   851  		ipv6wZonePort,
   852  		ipv4wPort,
   853  	}
   854  
   855  	out4, out6, hasWildcard, err := parseListeners(ips)
   856  	if err != nil {
   857  		t.Fatalf("error parsing listeners: %v", err)
   858  	}
   859  	if len(out4) != 1 {
   860  		t.Fatalf("expected 1 ipv4 addresses. found %d", len(out4))
   861  	}
   862  	if len(out6) != 2 {
   863  		t.Fatalf("expected 2 ipv6 addresses. found %d", len(out6))
   864  	}
   865  	if hasWildcard {
   866  		t.Fatal("hasWildcard true. should be false.")
   867  	}
   868  
   869  	// Port-only address goes in both.
   870  	ips = append(ips, ":1234")
   871  	out4, out6, hasWildcard, err = parseListeners(ips)
   872  	if err != nil {
   873  		t.Fatalf("error parsing listeners with wildcard: %v", err)
   874  	}
   875  	if len(out4) != 2 {
   876  		t.Fatalf("expected 2 ipv4 addresses. found %d", len(out4))
   877  	}
   878  	if len(out6) != 3 {
   879  		t.Fatalf("expected 3 ipv6 addresses. found %d", len(out6))
   880  	}
   881  	if !hasWildcard {
   882  		t.Fatal("hasWildcard false with port-only address")
   883  	}
   884  
   885  	// No port is invalid
   886  	ips = append(ips, "localhost")
   887  	_, _, _, err = parseListeners(ips)
   888  	if err == nil {
   889  		t.Fatal("no error when no IP specified")
   890  	}
   891  
   892  	// Pass invalid address
   893  	_, _, _, err = parseListeners([]string{ipv6Invalid})
   894  	if err == nil {
   895  		t.Fatal("no error with invalid address")
   896  	}
   897  }
   898  
   899  type tHTTPHandler struct {
   900  	count uint32
   901  }
   902  
   903  func (h *tHTTPHandler) ServeHTTP(http.ResponseWriter, *http.Request) {
   904  	atomic.AddUint32(&h.count, 1)
   905  }
   906  
   907  func TestHTTPRateLimiter(t *testing.T) {
   908  	tHandler := &tHTTPHandler{}
   909  	s := Server{dataEnabled: 1}
   910  
   911  	f := s.LimitRate(tHandler)
   912  	ip := "ip"
   913  	req := &http.Request{RemoteAddr: ip}
   914  	recorder := httptest.NewRecorder()
   915  	for i := 0; i < ipMaxBurstSize; i++ {
   916  		f.ServeHTTP(recorder, req)
   917  	}
   918  	time.Sleep(100 * time.Millisecond)
   919  	f.ServeHTTP(recorder, req)
   920  	successes := atomic.LoadUint32(&tHandler.count)
   921  	if successes != ipMaxBurstSize {
   922  		t.Fatalf("expected %d requests. got %d", ipMaxBurstSize, successes)
   923  	}
   924  	statusCode := recorder.Result().StatusCode
   925  	if statusCode != http.StatusTooManyRequests {
   926  		t.Fatalf("wrong status code. wanted %d, got %d", http.StatusTooManyRequests, statusCode)
   927  	}
   928  }
   929  
   930  func TestWSRateLimiter(t *testing.T) {
   931  	server := newServer()
   932  	var wg sync.WaitGroup
   933  	defer func() {
   934  		server.disconnectClients()
   935  		wg.Wait()
   936  	}()
   937  
   938  	handled := make(chan struct{}, 1)
   939  
   940  	server.Route(msgjson.FeeRateRoute, func(Link, *msgjson.Message) *msgjson.Error {
   941  		handled <- struct{}{}
   942  		return nil
   943  	})
   944  
   945  	server.Route(msgjson.OrderBookRoute, func(Link, *msgjson.Message) *msgjson.Error {
   946  		handled <- struct{}{}
   947  		return nil
   948  	})
   949  
   950  	conn := newWsStub()
   951  	conn.addChan()     // for <-conn.recv
   952  	conn.addNextChan() // for <-conn.nextRead, each time ReadMessage is called
   953  
   954  	wg.Add(1)
   955  	go func(conn *wsConnStub) {
   956  		defer wg.Done()
   957  		stubAddr := dex.NewIPKey("aabb:cc:ddee:ff::abc") // "abc" chopped by NewIPKey, "ff" chopped by PrefixV6
   958  		if stubAddr.IsUnspecified() {
   959  			t.Errorf("bad addr")
   960  			return
   961  		}
   962  		server.websocketHandler(testCtx, conn, stubAddr) // newWSLink -> Connect -> readloop will call handleMessage
   963  		close(conn.nextRead)                             // must be after read loop has quit (sends on nextRead)
   964  	}(conn)
   965  
   966  	<-conn.nextRead
   967  	go func() { // connected, so just keep receiving on the channel
   968  		for range conn.nextRead {
   969  		}
   970  	}()
   971  
   972  	waitResult := func() int {
   973  		t.Helper()
   974  		select {
   975  		case <-handled:
   976  			return 0
   977  		case resp := <-conn.recv: // test handlers only return resp with error (rate limit)
   978  			// t.Log("handler error message:", string(resp))
   979  			msg, err := msgjson.DecodeMessage(resp)
   980  			if err != nil {
   981  				t.Fatalf("failed to decode response message: %v", err)
   982  			}
   983  			payload, err := msg.Response()
   984  			if err != nil {
   985  				t.Fatalf("failed to decode response: %v", err)
   986  			}
   987  			if payload.Error == nil {
   988  				t.Fatalf("Expected rate limiting error, got none.")
   989  			}
   990  			if payload.Error.Code != msgjson.TooManyRequestsError {
   991  				t.Fatalf("Wanted code %d, got %d.", msgjson.TooManyRequestsError, payload.Error.Code)
   992  			}
   993  			if !strings.HasPrefix(payload.Error.Message, "too many requests") {
   994  				t.Fatalf("Wanted message with prefix %q, got %q.", "too many requests", payload.Error.Message)
   995  			}
   996  			return 1
   997  		case <-time.After(5 * time.Second):
   998  			t.Fatal("timeout")
   999  		}
  1000  		return 2
  1001  	}
  1002  
  1003  	// Other routes still work.
  1004  	sendToConn(t, conn, msgjson.FeeRateRoute, `{}`)
  1005  	if waitResult() != 0 {
  1006  		t.Fatalf("fee_rate request failed")
  1007  	}
  1008  
  1009  	// orderbook, which has 1 r/s rate limit, 100 burst
  1010  	sendToConn(t, conn, msgjson.OrderBookRoute, `{}`)
  1011  	if waitResult() != 0 {
  1012  		t.Fatalf("orderbook request failed")
  1013  	}
  1014  	sendToConn(t, conn, msgjson.OrderBookRoute, `{}`)
  1015  	if waitResult() != 0 { // tests burst > 1
  1016  		t.Fatalf("orderbook request failed")
  1017  	}
  1018  
  1019  	// New connection from different address.
  1020  	conn = newWsStub()
  1021  	conn.addChan()     // for <-conn.recv
  1022  	conn.addNextChan() // for <-conn.nextRead, each time ReadMessage is called
  1023  
  1024  	wg.Add(1)
  1025  	go func(conn *wsConnStub) {
  1026  		defer wg.Done()
  1027  		stubAddr := dex.NewIPKey("aabb:cc:ddee:11::") // same prefix, different subnet
  1028  		if stubAddr.IsUnspecified() {
  1029  			t.Errorf("bad addr")
  1030  			return
  1031  		}
  1032  		server.websocketHandler(testCtx, conn, stubAddr) // newWSLink -> Connect -> readloop will call handleMessage
  1033  		close(conn.nextRead)                             // must be after read loop has quit (sends on nextRead)
  1034  	}(conn)
  1035  
  1036  	<-conn.nextRead
  1037  	go func() { // connected, so just keep receiving on the channel
  1038  		for range conn.nextRead {
  1039  		}
  1040  	}()
  1041  }