github.com/ari-anchor/sei-tendermint@v0.0.0-20230519144642-dc826b7b56bb/rpc/jsonrpc/client/ws_client_test.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/fortytw2/leaktest"
    13  	"github.com/gorilla/websocket"
    14  	"github.com/stretchr/testify/assert"
    15  	"github.com/stretchr/testify/require"
    16  
    17  	rpctypes "github.com/ari-anchor/sei-tendermint/rpc/jsonrpc/types"
    18  )
    19  
    20  const wsCallTimeout = 5 * time.Second
    21  
    22  type myTestHandler struct {
    23  	closeConnAfterRead bool
    24  	mtx                sync.RWMutex
    25  	t                  *testing.T
    26  }
    27  
    28  var upgrader = websocket.Upgrader{
    29  	ReadBufferSize:  1024,
    30  	WriteBufferSize: 1024,
    31  }
    32  
    33  func (h *myTestHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    34  	conn, err := upgrader.Upgrade(w, r, nil)
    35  	require.NoError(h.t, err)
    36  
    37  	defer conn.Close()
    38  	for {
    39  		messageType, in, err := conn.ReadMessage()
    40  		if err != nil {
    41  			return
    42  		}
    43  
    44  		var req rpctypes.RPCRequest
    45  		err = json.Unmarshal(in, &req)
    46  		require.NoError(h.t, err)
    47  
    48  		func() {
    49  			h.mtx.RLock()
    50  			defer h.mtx.RUnlock()
    51  
    52  			if h.closeConnAfterRead {
    53  				require.NoError(h.t, conn.Close())
    54  			}
    55  		}()
    56  
    57  		res := json.RawMessage(`{}`)
    58  
    59  		emptyRespBytes, err := json.Marshal(req.MakeResponse(res))
    60  		require.NoError(h.t, err)
    61  		if err := conn.WriteMessage(messageType, emptyRespBytes); err != nil {
    62  			return
    63  		}
    64  	}
    65  }
    66  
    67  func TestWSClientReconnectsAfterReadFailure(t *testing.T) {
    68  	t.Cleanup(leaktest.Check(t))
    69  
    70  	// start server
    71  	h := &myTestHandler{t: t}
    72  	s := httptest.NewServer(h)
    73  	defer s.Close()
    74  
    75  	ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
    76  	defer cancel()
    77  
    78  	c := startClient(ctx, t, "//"+s.Listener.Addr().String())
    79  
    80  	go handleResponses(ctx, t, c)
    81  
    82  	h.mtx.Lock()
    83  	h.closeConnAfterRead = true
    84  	h.mtx.Unlock()
    85  
    86  	// results in WS read error, no send retry because write succeeded
    87  	call(ctx, t, "a", c)
    88  
    89  	// expect to reconnect almost immediately
    90  	time.Sleep(10 * time.Millisecond)
    91  	h.mtx.Lock()
    92  	h.closeConnAfterRead = false
    93  	h.mtx.Unlock()
    94  
    95  	// should succeed
    96  	call(ctx, t, "b", c)
    97  }
    98  
    99  func TestWSClientReconnectsAfterWriteFailure(t *testing.T) {
   100  	t.Cleanup(leaktest.Check(t))
   101  
   102  	// start server
   103  	h := &myTestHandler{t: t}
   104  	s := httptest.NewServer(h)
   105  	defer s.Close()
   106  
   107  	ctx, cancel := context.WithCancel(context.Background())
   108  	defer cancel()
   109  
   110  	c := startClient(ctx, t, "//"+s.Listener.Addr().String())
   111  
   112  	go handleResponses(ctx, t, c)
   113  
   114  	// hacky way to abort the connection before write
   115  	if err := c.conn.Close(); err != nil {
   116  		t.Error(err)
   117  	}
   118  
   119  	// results in WS write error, the client should resend on reconnect
   120  	call(ctx, t, "a", c)
   121  
   122  	// expect to reconnect almost immediately
   123  	time.Sleep(10 * time.Millisecond)
   124  
   125  	// should succeed
   126  	call(ctx, t, "b", c)
   127  }
   128  
   129  func TestWSClientReconnectFailure(t *testing.T) {
   130  	t.Cleanup(leaktest.Check(t))
   131  
   132  	// start server
   133  	h := &myTestHandler{t: t}
   134  	s := httptest.NewServer(h)
   135  
   136  	ctx, cancel := context.WithCancel(context.Background())
   137  	defer cancel()
   138  
   139  	c := startClient(ctx, t, "//"+s.Listener.Addr().String())
   140  
   141  	go func() {
   142  		for {
   143  			select {
   144  			case <-c.ResponsesCh:
   145  			case <-ctx.Done():
   146  				return
   147  			}
   148  		}
   149  	}()
   150  
   151  	// hacky way to abort the connection before write
   152  	if err := c.conn.Close(); err != nil {
   153  		t.Error(err)
   154  	}
   155  	s.Close()
   156  
   157  	// results in WS write error
   158  	// provide timeout to avoid blocking
   159  	cctx, cancel := context.WithTimeout(ctx, wsCallTimeout)
   160  	defer cancel()
   161  	if err := c.Call(cctx, "a", make(map[string]interface{})); err != nil {
   162  		t.Error(err)
   163  	}
   164  
   165  	// expect to reconnect almost immediately
   166  	time.Sleep(10 * time.Millisecond)
   167  
   168  	done := make(chan struct{})
   169  	go func() {
   170  		// client should block on this
   171  		call(ctx, t, "b", c)
   172  		close(done)
   173  	}()
   174  
   175  	// test that client blocks on the second send
   176  	select {
   177  	case <-done:
   178  		t.Fatal("client should block on calling 'b' during reconnect")
   179  	case <-time.After(5 * time.Second):
   180  		t.Log("All good")
   181  	}
   182  }
   183  
   184  func TestNotBlockingOnStop(t *testing.T) {
   185  	t.Cleanup(leaktest.Check(t))
   186  
   187  	s := httptest.NewServer(&myTestHandler{t: t})
   188  	defer s.Close()
   189  	ctx, cancel := context.WithCancel(context.Background())
   190  	defer cancel()
   191  
   192  	c := startClient(ctx, t, "//"+s.Listener.Addr().String())
   193  	require.NoError(t, c.Call(ctx, "a", make(map[string]interface{})))
   194  
   195  	time.Sleep(200 * time.Millisecond) // give service routines time to start ⚠️
   196  	done := make(chan struct{})
   197  	go func() {
   198  		cancel()
   199  		if assert.NoError(t, c.Stop()) {
   200  			close(done)
   201  		}
   202  	}()
   203  	select {
   204  	case <-done:
   205  		t.Log("Stopped client successfully")
   206  	case <-time.After(2 * time.Second):
   207  		t.Fatal("Timed out waiting for client to stop")
   208  	}
   209  }
   210  
   211  func startClient(ctx context.Context, t *testing.T, addr string) *WSClient {
   212  	t.Helper()
   213  
   214  	t.Cleanup(leaktest.Check(t))
   215  
   216  	c, err := NewWS(addr, "/websocket")
   217  	require.NoError(t, err)
   218  	require.NoError(t, c.Start(ctx))
   219  	return c
   220  }
   221  
   222  func call(ctx context.Context, t *testing.T, method string, c *WSClient) {
   223  	t.Helper()
   224  
   225  	err := c.Call(ctx, method, make(map[string]interface{}))
   226  	if ctx.Err() == nil {
   227  		require.NoError(t, err)
   228  	}
   229  }
   230  
   231  func handleResponses(ctx context.Context, t *testing.T, c *WSClient) {
   232  	t.Helper()
   233  
   234  	for {
   235  		select {
   236  		case resp := <-c.ResponsesCh:
   237  			if resp.Error != nil {
   238  				t.Errorf("unexpected error: %v", resp.Error)
   239  				return
   240  			}
   241  			if resp.Result != nil {
   242  				return
   243  			}
   244  		case <-ctx.Done():
   245  			return
   246  		}
   247  	}
   248  }