github.com/okex/exchain@v1.8.0/libs/tendermint/rpc/jsonrpc/client/ws_client_test.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"net/http"
     7  	"net/http/httptest"
     8  	"sync"
     9  	"testing"
    10  	"time"
    11  
    12  	"github.com/gorilla/websocket"
    13  	"github.com/stretchr/testify/require"
    14  
    15  	"github.com/okex/exchain/libs/tendermint/libs/log"
    16  
    17  	types "github.com/okex/exchain/libs/tendermint/rpc/jsonrpc/types"
    18  )
    19  
    20  var wsCallTimeout = 5 * time.Second
    21  
    22  type myHandler struct {
    23  	closeConnAfterRead bool
    24  	mtx                sync.RWMutex
    25  }
    26  
    27  var upgrader = websocket.Upgrader{
    28  	ReadBufferSize:  1024,
    29  	WriteBufferSize: 1024,
    30  }
    31  
    32  func (h *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    33  	conn, err := upgrader.Upgrade(w, r, nil)
    34  	if err != nil {
    35  		panic(err)
    36  	}
    37  	defer conn.Close() // nolint: errcheck
    38  	for {
    39  		messageType, in, err := conn.ReadMessage()
    40  		if err != nil {
    41  			return
    42  		}
    43  
    44  		var req types.RPCRequest
    45  		err = json.Unmarshal(in, &req)
    46  		if err != nil {
    47  			panic(err)
    48  		}
    49  
    50  		h.mtx.RLock()
    51  		if h.closeConnAfterRead {
    52  			if err := conn.Close(); err != nil {
    53  				panic(err)
    54  			}
    55  		}
    56  		h.mtx.RUnlock()
    57  
    58  		res := json.RawMessage(`{}`)
    59  		emptyRespBytes, _ := json.Marshal(types.RPCResponse{Result: res, ID: req.ID})
    60  		if err := conn.WriteMessage(messageType, emptyRespBytes); err != nil {
    61  			return
    62  		}
    63  	}
    64  }
    65  
    66  func TestWSClientReconnectsAfterReadFailure(t *testing.T) {
    67  	var wg sync.WaitGroup
    68  
    69  	// start server
    70  	h := &myHandler{}
    71  	s := httptest.NewServer(h)
    72  	defer s.Close()
    73  
    74  	c := startClient(t, "//"+s.Listener.Addr().String())
    75  	defer c.Stop()
    76  
    77  	wg.Add(1)
    78  	go callWgDoneOnResult(t, c, &wg)
    79  
    80  	h.mtx.Lock()
    81  	h.closeConnAfterRead = true
    82  	h.mtx.Unlock()
    83  
    84  	// results in WS read error, no send retry because write succeeded
    85  	call(t, "a", c)
    86  
    87  	// expect to reconnect almost immediately
    88  	time.Sleep(10 * time.Millisecond)
    89  	h.mtx.Lock()
    90  	h.closeConnAfterRead = false
    91  	h.mtx.Unlock()
    92  
    93  	// should succeed
    94  	call(t, "b", c)
    95  
    96  	wg.Wait()
    97  }
    98  
    99  func TestWSClientReconnectsAfterWriteFailure(t *testing.T) {
   100  	var wg sync.WaitGroup
   101  
   102  	// start server
   103  	h := &myHandler{}
   104  	s := httptest.NewServer(h)
   105  
   106  	c := startClient(t, "//"+s.Listener.Addr().String())
   107  	defer c.Stop()
   108  
   109  	wg.Add(2)
   110  	go callWgDoneOnResult(t, c, &wg)
   111  
   112  	// hacky way to abort the connection before write
   113  	if err := c.conn.Close(); err != nil {
   114  		t.Error(err)
   115  	}
   116  
   117  	// results in WS write error, the client should resend on reconnect
   118  	call(t, "a", c)
   119  
   120  	// expect to reconnect almost immediately
   121  	time.Sleep(10 * time.Millisecond)
   122  
   123  	// should succeed
   124  	call(t, "b", c)
   125  
   126  	wg.Wait()
   127  }
   128  
   129  func TestWSClientReconnectFailure(t *testing.T) {
   130  	// start server
   131  	h := &myHandler{}
   132  	s := httptest.NewServer(h)
   133  
   134  	c := startClient(t, "//"+s.Listener.Addr().String())
   135  	defer c.Stop()
   136  
   137  	go func() {
   138  		for {
   139  			select {
   140  			case <-c.ResponsesCh:
   141  			case <-c.Quit():
   142  				return
   143  			}
   144  		}
   145  	}()
   146  
   147  	// hacky way to abort the connection before write
   148  	if err := c.conn.Close(); err != nil {
   149  		t.Error(err)
   150  	}
   151  	s.Close()
   152  
   153  	// results in WS write error
   154  	// provide timeout to avoid blocking
   155  	ctx, cancel := context.WithTimeout(context.Background(), wsCallTimeout)
   156  	defer cancel()
   157  	if err := c.Call(ctx, "a", make(map[string]interface{})); err != nil {
   158  		t.Error(err)
   159  	}
   160  
   161  	// expect to reconnect almost immediately
   162  	time.Sleep(10 * time.Millisecond)
   163  
   164  	done := make(chan struct{})
   165  	go func() {
   166  		// client should block on this
   167  		call(t, "b", c)
   168  		close(done)
   169  	}()
   170  
   171  	// test that client blocks on the second send
   172  	select {
   173  	case <-done:
   174  		t.Fatal("client should block on calling 'b' during reconnect")
   175  	case <-time.After(5 * time.Second):
   176  		t.Log("All good")
   177  	}
   178  }
   179  
   180  func TestNotBlockingOnStop(t *testing.T) {
   181  	timeout := 2 * time.Second
   182  	s := httptest.NewServer(&myHandler{})
   183  	c := startClient(t, "//"+s.Listener.Addr().String())
   184  	c.Call(context.Background(), "a", make(map[string]interface{}))
   185  	// Let the readRoutine get around to blocking
   186  	time.Sleep(time.Second)
   187  	passCh := make(chan struct{})
   188  	go func() {
   189  		// Unless we have a non-blocking write to ResponsesCh from readRoutine
   190  		// this blocks forever ont the waitgroup
   191  		c.Stop()
   192  		passCh <- struct{}{}
   193  	}()
   194  	select {
   195  	case <-passCh:
   196  		// Pass
   197  	case <-time.After(timeout):
   198  		t.Fatalf("WSClient did failed to stop within %v seconds - is one of the read/write routines blocking?",
   199  			timeout.Seconds())
   200  	}
   201  }
   202  
   203  func startClient(t *testing.T, addr string) *WSClient {
   204  	c, err := NewWS(addr, "/websocket")
   205  	require.Nil(t, err)
   206  	err = c.Start()
   207  	require.Nil(t, err)
   208  	c.SetLogger(log.TestingLogger())
   209  	return c
   210  }
   211  
   212  func call(t *testing.T, method string, c *WSClient) {
   213  	err := c.Call(context.Background(), method, make(map[string]interface{}))
   214  	require.NoError(t, err)
   215  }
   216  
   217  func callWgDoneOnResult(t *testing.T, c *WSClient, wg *sync.WaitGroup) {
   218  	for {
   219  		select {
   220  		case resp := <-c.ResponsesCh:
   221  			if resp.Error != nil {
   222  				t.Errorf("unexpected error: %v", resp.Error)
   223  				return
   224  			}
   225  			if resp.Result != nil {
   226  				wg.Done()
   227  			}
   228  		case <-c.Quit():
   229  			return
   230  		}
   231  	}
   232  }