decred.org/dcrdex@v1.0.5/client/comms/wsconn_test.go (about)

     1  package comms
     2  
     3  import (
     4  	"bytes"
     5  	"context"
     6  	"crypto/elliptic"
     7  	"encoding/hex"
     8  	"errors"
     9  	"fmt"
    10  	"net"
    11  	"net/http"
    12  	"os"
    13  	"runtime"
    14  	"sync"
    15  	"sync/atomic"
    16  	"testing"
    17  	"time"
    18  
    19  	"decred.org/dcrdex/dex"
    20  	"decred.org/dcrdex/dex/msgjson"
    21  	"github.com/decred/dcrd/certgen"
    22  	"github.com/gorilla/websocket"
    23  )
    24  
    25  var tLogger = dex.StdOutLogger("conn_TEST", dex.LevelTrace)
    26  
    27  func makeRequest(id uint64, route string, msg any) *msgjson.Message {
    28  	req, _ := msgjson.NewRequest(id, route, msg)
    29  	return req
    30  }
    31  
    32  // genCertPair generates a key/cert pair to the paths provided.
    33  func genCertPair(certFile, keyFile string, altDNSNames []string) error {
    34  	tLogger.Infof("Generating TLS certificates...")
    35  
    36  	org := "dcrdex autogenerated cert"
    37  	validUntil := time.Now().Add(10 * 365 * 24 * time.Hour)
    38  	cert, key, err := certgen.NewTLSCertPair(elliptic.P521(), org,
    39  		validUntil, altDNSNames)
    40  	if err != nil {
    41  		return err
    42  	}
    43  
    44  	// Write cert and key files.
    45  	if err = os.WriteFile(certFile, cert, 0644); err != nil {
    46  		return err
    47  	}
    48  	if err = os.WriteFile(keyFile, key, 0600); err != nil {
    49  		os.Remove(certFile)
    50  		return err
    51  	}
    52  
    53  	tLogger.Infof("Done generating TLS certificates")
    54  	return nil
    55  }
    56  
    57  func TestWsConn(t *testing.T) {
    58  	// Must wait for goroutines, especially the ones that capture t.
    59  	var wg sync.WaitGroup
    60  	defer wg.Wait()
    61  
    62  	upgrader := websocket.Upgrader{}
    63  
    64  	pingCh := make(chan struct{})
    65  	readPumpCh := make(chan any)
    66  	writePumpCh := make(chan *msgjson.Message)
    67  	ctx, cancel := context.WithCancel(context.Background())
    68  	defer cancel()
    69  
    70  	type conn struct {
    71  		sync.WaitGroup
    72  		*websocket.Conn
    73  	}
    74  	var clientMtx sync.Mutex
    75  	clients := make(map[uint64]*conn)
    76  
    77  	// server.Shutdown does not wait for hijacked connections, and pong handler
    78  	// uses t.Logf.
    79  	defer func() {
    80  		clientMtx.Lock()
    81  		for id, h := range clients {
    82  			h.Close()
    83  			h.Wait()
    84  			delete(clients, id)
    85  		}
    86  		clientMtx.Unlock()
    87  	}()
    88  
    89  	var id uint64
    90  	// server's "/ws" handler
    91  	handler := func(w http.ResponseWriter, r *http.Request) {
    92  		t.Helper()
    93  		id := atomic.AddUint64(&id, 1) // shadow id
    94  		hCtx, hCancel := context.WithCancel(ctx)
    95  
    96  		c, err := upgrader.Upgrade(w, r, nil)
    97  		if err != nil {
    98  			t.Errorf("unable to upgrade http connection: %s", err)
    99  		}
   100  
   101  		ch := &conn{Conn: c}
   102  		clientMtx.Lock()
   103  		clients[id] = ch
   104  		clientMtx.Unlock()
   105  
   106  		c.SetPongHandler(func(string) error {
   107  			t.Logf("handler #%d: pong received", id)
   108  			return nil
   109  		})
   110  
   111  		ch.Add(1)
   112  		go func() {
   113  			defer ch.Done()
   114  			for {
   115  				select {
   116  				case <-pingCh:
   117  					err := c.WriteControl(websocket.PingMessage, []byte{},
   118  						time.Now().Add(writeWait))
   119  					if err != nil {
   120  						if hCtx.Err() == nil {
   121  							// Only a failure if the server isn't shutting down.
   122  							t.Errorf("handler #%d: ping error: %v", id, err)
   123  						}
   124  						return
   125  					}
   126  
   127  					t.Logf("handler #%d: ping sent", id)
   128  
   129  				case msg := <-readPumpCh:
   130  					err := c.WriteJSON(msg)
   131  					if err != nil {
   132  						t.Errorf("handler #%d: write error: %v", id, err)
   133  						return
   134  					}
   135  
   136  				case <-hCtx.Done():
   137  					return
   138  				}
   139  			}
   140  		}()
   141  
   142  		ch.Add(1)
   143  		go func() {
   144  			defer ch.Done()
   145  			for {
   146  				mType, message, err := c.ReadMessage()
   147  				if err != nil {
   148  					hCancel()
   149  					c.Close()
   150  
   151  					// If the context has been canceled, don't do anything.
   152  					if hCtx.Err() != nil {
   153  						return
   154  					}
   155  
   156  					if websocket.IsCloseError(err, websocket.CloseNormalClosure) {
   157  						// Terminate on a normal close message.
   158  						return
   159  					}
   160  
   161  					t.Errorf("handler #%d: read error: %v\n", id, err)
   162  					return
   163  				}
   164  
   165  				if mType == websocket.TextMessage {
   166  					msg, err := msgjson.DecodeMessage(message)
   167  					if err != nil {
   168  						t.Errorf("handler #%d: decode error: %v", id, err)
   169  						continue // Don't hang up.
   170  					}
   171  
   172  					writePumpCh <- msg
   173  				}
   174  			}
   175  		}()
   176  	}
   177  
   178  	certFile, err := os.CreateTemp("", "certfile")
   179  	if err != nil {
   180  		t.Fatalf("unable to create temp certfile: %s", err)
   181  	}
   182  	certFile.Close()
   183  	defer os.Remove(certFile.Name())
   184  
   185  	keyFile, err := os.CreateTemp("", "keyfile")
   186  	if err != nil {
   187  		t.Fatalf("unable to create temp keyfile: %s", err)
   188  	}
   189  	keyFile.Close()
   190  	defer os.Remove(keyFile.Name())
   191  
   192  	err = genCertPair(certFile.Name(), keyFile.Name(), nil)
   193  	if err != nil {
   194  		t.Fatal(err)
   195  	}
   196  
   197  	certB, err := os.ReadFile(certFile.Name())
   198  	if err != nil {
   199  		t.Fatalf("file reading error: %v", err)
   200  	}
   201  
   202  	host := "127.0.0.1:0"
   203  	mux := http.NewServeMux()
   204  	mux.HandleFunc("/ws", handler)
   205  
   206  	// http server for the connect and upgrade
   207  	server := &http.Server{
   208  		WriteTimeout: time.Second * 10,
   209  		ReadTimeout:  time.Second * 10,
   210  		Addr:         host,
   211  		Handler:      mux,
   212  	}
   213  	defer server.Shutdown(context.Background())
   214  
   215  	wg.Add(1)
   216  	serverReady := make(chan error, 1)
   217  	go func() {
   218  		defer wg.Done()
   219  
   220  		ln, err := net.Listen("tcp", server.Addr)
   221  		if err != nil {
   222  			serverReady <- err
   223  			return
   224  		}
   225  		defer ln.Close()
   226  		//log.Info(ln.Addr().(*net.TCPAddr).Port)
   227  		host = ln.Addr().String()
   228  		serverReady <- nil // after setting host
   229  
   230  		err = server.ServeTLS(ln, certFile.Name(), keyFile.Name())
   231  		if err != nil {
   232  			fmt.Println(err)
   233  		}
   234  	}()
   235  
   236  	// wait for server to start listening before connecting
   237  	err = <-serverReady
   238  	if err != nil {
   239  		t.Fatal(err)
   240  	}
   241  
   242  	const pingWait = 500 * time.Millisecond
   243  	setupWsConn := func(cert []byte) (*wsConn, error) {
   244  		cfg := &WsCfg{
   245  			URL:      "wss://" + host + "/ws",
   246  			PingWait: pingWait,
   247  			Cert:     cert,
   248  			Logger:   tLogger,
   249  		}
   250  		conn, err := NewWsConn(cfg)
   251  		if err != nil {
   252  			return nil, err
   253  		}
   254  		return conn.(*wsConn), nil
   255  	}
   256  
   257  	// test no cert error
   258  	noCertConn, err := setupWsConn(nil)
   259  	if err != nil {
   260  		t.Fatal(err)
   261  	}
   262  	noCertConnMaster := dex.NewConnectionMaster(noCertConn)
   263  	err = noCertConnMaster.Connect(ctx)
   264  	noCertConnMaster.Disconnect()
   265  	if err == nil || !errors.Is(err, ErrCertRequired) {
   266  		t.Fatalf("failed to get ErrCertRequired for no cert connection, got %v", err)
   267  	}
   268  
   269  	// test invalid cert error
   270  	_, err = setupWsConn([]byte("invalid cert"))
   271  	if err == nil || !errors.Is(err, ErrInvalidCert) {
   272  		t.Fatalf("failed to get ErrInvalidCert for invalid cert connection, got %v", err)
   273  	}
   274  
   275  	// connect with cert
   276  	wsc, err := setupWsConn(certB)
   277  	if err != nil {
   278  		t.Fatal(err)
   279  	}
   280  	waiter := dex.NewConnectionMaster(wsc)
   281  	err = waiter.Connect(ctx)
   282  	if err != nil {
   283  		t.Fatalf("Connect: %v", err)
   284  	}
   285  
   286  	reconnectAndPing := func() {
   287  		// Drop the connection and force a reconnect by waiting longer than the
   288  		// read deadline (the ping wait), plus a bit extra to allow the timeout
   289  		// to flip off the connection and queue a reconnect.
   290  		time.Sleep(pingWait * 3 / 2)
   291  		runtime.Gosched()
   292  
   293  		// Wait for a reconnection.
   294  		for wsc.IsDown() {
   295  			time.Sleep(time.Millisecond * 10)
   296  			continue
   297  		}
   298  
   299  		// Send a ping.
   300  		pingCh <- struct{}{}
   301  	}
   302  
   303  	orderid, _ := hex.DecodeString("ceb09afa675cee31c0f858b94c81bd1a4c2af8c5947d13e544eef772381f2c8d")
   304  	matchid, _ := hex.DecodeString("7c6b44735e303585d644c713fe0e95897e7e8ba2b9bba98d6d61b70006d3d58c")
   305  	match := &msgjson.Match{
   306  		OrderID:  orderid,
   307  		MatchID:  matchid,
   308  		Quantity: 20,
   309  		Rate:     2,
   310  		Address:  "DsiNAJCd2sSazZRU9ViDD334DaLgU1Kse3P",
   311  	}
   312  
   313  	// Ensure a malformed message to the client does not terminate
   314  	// the connection.
   315  	readPumpCh <- []byte("{notjson")
   316  
   317  	// Send a message to the client.
   318  	sent := makeRequest(1, msgjson.MatchRoute, match)
   319  	readPumpCh <- sent
   320  
   321  	// Fetch the read source.
   322  	readSource := wsc.MessageSource()
   323  	if readSource == nil {
   324  		t.Fatal("expected a non-nil read source")
   325  	}
   326  
   327  	// Read the message received by the client.
   328  	received := <-readSource
   329  
   330  	// Ensure the received message equal to the sent message.
   331  	if received.Type != sent.Type {
   332  		t.Fatalf("expected %v type, got %v", sent.Type, received.Type)
   333  	}
   334  
   335  	if received.Route != sent.Route {
   336  		t.Fatalf("expected %v route, got %v", sent.Route, received.Route)
   337  	}
   338  
   339  	if received.ID != sent.ID {
   340  		t.Fatalf("expected %v id, got %v", sent.ID, received.ID)
   341  	}
   342  
   343  	if !bytes.Equal(received.Payload, sent.Payload) {
   344  		t.Fatal("sent and received payload mismatch")
   345  	}
   346  
   347  	reconnectAndPing()
   348  
   349  	coinID := []byte{
   350  		0xc3, 0x16, 0x10, 0x33, 0xde, 0x09, 0x6f, 0xd7, 0x4d, 0x90, 0x51, 0xff,
   351  		0x0b, 0xd9, 0x9e, 0x35, 0x9d, 0xe3, 0x50, 0x80, 0xa3, 0x51, 0x10, 0x81,
   352  		0xed, 0x03, 0x5f, 0x54, 0x1b, 0x85, 0x0d, 0x43, 0x00, 0x00, 0x00, 0x0a,
   353  	}
   354  
   355  	contract, _ := hex.DecodeString("caf8d277f80f71e4")
   356  	init := &msgjson.Init{
   357  		OrderID:  orderid,
   358  		MatchID:  matchid,
   359  		CoinID:   coinID,
   360  		Contract: contract,
   361  	}
   362  
   363  	// Send a message from the client.
   364  	mId := wsc.NextID()
   365  	sent = makeRequest(mId, msgjson.InitRoute, init)
   366  	handlerRun := false
   367  	err = wsc.Request(sent, func(*msgjson.Message) {
   368  		handlerRun = true
   369  	})
   370  	if err != nil {
   371  		t.Fatalf("unexpected error: %v", err)
   372  	}
   373  
   374  	// Read the message received by the server.
   375  	received = <-writePumpCh
   376  
   377  	// Ensure the received message equal to the sent message.
   378  	if received.Type != sent.Type {
   379  		t.Fatalf("expected %v type, got %v", sent.Type, received.Type)
   380  	}
   381  
   382  	if received.Route != sent.Route {
   383  		t.Fatalf("expected %v route, got %v", sent.Route, received.Route)
   384  	}
   385  
   386  	if received.ID != sent.ID {
   387  		t.Fatalf("expected %v id, got %v", sent.ID, received.ID)
   388  	}
   389  
   390  	if !bytes.Equal(received.Payload, sent.Payload) {
   391  		t.Fatal("sent and received payload mismatch")
   392  	}
   393  
   394  	// Ensure the next id is as expected.
   395  	next := wsc.NextID()
   396  	if next != 2 {
   397  		t.Fatalf("expected next id to be %d, got %d", 2, next)
   398  	}
   399  
   400  	// Ensure the request got logged, also unregister the response handler.
   401  	hndlr := wsc.respHandler(mId)
   402  	if hndlr == nil {
   403  		t.Fatalf("no handler found")
   404  	}
   405  	hndlr.f(nil)
   406  	if !handlerRun {
   407  		t.Fatalf("wrong handler retrieved")
   408  	}
   409  
   410  	// Ensure the response handler is unlogged.
   411  	hndlr = wsc.respHandler(mId)
   412  	if hndlr != nil {
   413  		t.Fatal("found a response handler for an unlogged request id")
   414  	}
   415  
   416  	pingCh <- struct{}{}
   417  
   418  	// Ensure malformed request data (a send failure) does not leave a
   419  	// registered response handler or kill the connection.
   420  	sent.ID = wsc.NextID()
   421  	sent.Payload = []byte("{notjson")
   422  	err = wsc.Request(sent, func(*msgjson.Message) {})
   423  	if err == nil {
   424  		t.Fatalf("expected error with malformed request payload")
   425  	}
   426  
   427  	// Ensure the response handler is unregistered.
   428  	if wsc.respHandler(mId) != nil {
   429  		t.Fatal("response handler was still registered")
   430  	}
   431  
   432  	// New request to test expiration.
   433  	mId = next
   434  	sent = makeRequest(mId, msgjson.InitRoute, init)
   435  	expiring := make(chan struct{}, 1)
   436  	expTime := 50 * time.Millisecond // way shorter than pingWait
   437  	err = wsc.RequestWithTimeout(sent, func(*msgjson.Message) {}, expTime, func() {
   438  		expiring <- struct{}{}
   439  	})
   440  	if err != nil {
   441  		t.Fatalf("unexpected error: %v", err)
   442  	}
   443  	<-writePumpCh
   444  
   445  	pingCh <- struct{}{}
   446  
   447  	// Yield to the comms goroutine in case this machine is poor.
   448  	runtime.Gosched()
   449  	select {
   450  	case <-expiring:
   451  	case <-time.NewTimer(time.Second).C: // >> expTime
   452  		t.Fatalf("didn't expire") // conn will be dead by this time without pings
   453  	}
   454  
   455  	// New request to abort on conn shutdown.
   456  	sent = makeRequest(wsc.NextID(), msgjson.InitRoute, init)
   457  	expiring = make(chan struct{}, 1)
   458  	expTime = 20 * time.Second                  // we're going to cancel first
   459  	beforeExpire := time.After(2 * time.Second) // enough time for shutdown to call expire func
   460  	err = wsc.RequestWithTimeout(sent, func(*msgjson.Message) {}, expTime, func() {
   461  		expiring <- struct{}{}
   462  	})
   463  	if err != nil {
   464  		t.Fatalf("unexpected error: %v", err)
   465  	}
   466  	<-writePumpCh
   467  
   468  	pingCh <- struct{}{}
   469  
   470  	// Shutdown/Disconnect before expire.
   471  	time.Sleep(50 * time.Millisecond) // let pings and pongs flush, but it's not a problem if they bomb
   472  	waiter.Disconnect()
   473  
   474  	select {
   475  	case <-beforeExpire: // much shorter than req timeout
   476  		t.Error("expire func not called on conn shutdown")
   477  	case <-expiring: // means aborted if triggered before timeout
   478  	}
   479  
   480  	select {
   481  	case _, ok := <-readSource:
   482  		if ok {
   483  			t.Error("read source should have been closed")
   484  		}
   485  	default:
   486  		t.Error("read source should have been closed")
   487  	}
   488  }