github.com/evdatsion/aphelion-dpos-bft@v0.32.1/rpc/lib/client/ws_client_test.go (about)

     1  package rpcclient
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"net"
     7  	"net/http"
     8  	"net/http/httptest"
     9  	"sync"
    10  	"testing"
    11  	"time"
    12  
    13  	"github.com/gorilla/websocket"
    14  	"github.com/stretchr/testify/require"
    15  	"github.com/evdatsion/aphelion-dpos-bft/libs/log"
    16  
    17  	types "github.com/evdatsion/aphelion-dpos-bft/rpc/lib/types"
    18  )
    19  
    20  var wsCallTimeout = 5 * time.Second
    21  
    22  type myHandler struct {
    23  	closeConnAfterRead bool
    24  	mtx                sync.RWMutex
    25  }
    26  
    27  var upgrader = websocket.Upgrader{
    28  	ReadBufferSize:  1024,
    29  	WriteBufferSize: 1024,
    30  }
    31  
    32  func (h *myHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
    33  	conn, err := upgrader.Upgrade(w, r, nil)
    34  	if err != nil {
    35  		panic(err)
    36  	}
    37  	defer conn.Close() // nolint: errcheck
    38  	for {
    39  		messageType, _, err := conn.ReadMessage()
    40  		if err != nil {
    41  			return
    42  		}
    43  
    44  		h.mtx.RLock()
    45  		if h.closeConnAfterRead {
    46  			if err := conn.Close(); err != nil {
    47  				panic(err)
    48  			}
    49  		}
    50  		h.mtx.RUnlock()
    51  
    52  		res := json.RawMessage(`{}`)
    53  		emptyRespBytes, _ := json.Marshal(types.RPCResponse{Result: res})
    54  		if err := conn.WriteMessage(messageType, emptyRespBytes); err != nil {
    55  			return
    56  		}
    57  	}
    58  }
    59  
    60  func TestWSClientReconnectsAfterReadFailure(t *testing.T) {
    61  	var wg sync.WaitGroup
    62  
    63  	// start server
    64  	h := &myHandler{}
    65  	s := httptest.NewServer(h)
    66  	defer s.Close()
    67  
    68  	c := startClient(t, s.Listener.Addr())
    69  	defer c.Stop()
    70  
    71  	wg.Add(1)
    72  	go callWgDoneOnResult(t, c, &wg)
    73  
    74  	h.mtx.Lock()
    75  	h.closeConnAfterRead = true
    76  	h.mtx.Unlock()
    77  
    78  	// results in WS read error, no send retry because write succeeded
    79  	call(t, "a", c)
    80  
    81  	// expect to reconnect almost immediately
    82  	time.Sleep(10 * time.Millisecond)
    83  	h.mtx.Lock()
    84  	h.closeConnAfterRead = false
    85  	h.mtx.Unlock()
    86  
    87  	// should succeed
    88  	call(t, "b", c)
    89  
    90  	wg.Wait()
    91  }
    92  
    93  func TestWSClientReconnectsAfterWriteFailure(t *testing.T) {
    94  	var wg sync.WaitGroup
    95  
    96  	// start server
    97  	h := &myHandler{}
    98  	s := httptest.NewServer(h)
    99  
   100  	c := startClient(t, s.Listener.Addr())
   101  	defer c.Stop()
   102  
   103  	wg.Add(2)
   104  	go callWgDoneOnResult(t, c, &wg)
   105  
   106  	// hacky way to abort the connection before write
   107  	if err := c.conn.Close(); err != nil {
   108  		t.Error(err)
   109  	}
   110  
   111  	// results in WS write error, the client should resend on reconnect
   112  	call(t, "a", c)
   113  
   114  	// expect to reconnect almost immediately
   115  	time.Sleep(10 * time.Millisecond)
   116  
   117  	// should succeed
   118  	call(t, "b", c)
   119  
   120  	wg.Wait()
   121  }
   122  
   123  func TestWSClientReconnectFailure(t *testing.T) {
   124  	// start server
   125  	h := &myHandler{}
   126  	s := httptest.NewServer(h)
   127  
   128  	c := startClient(t, s.Listener.Addr())
   129  	defer c.Stop()
   130  
   131  	go func() {
   132  		for {
   133  			select {
   134  			case <-c.ResponsesCh:
   135  			case <-c.Quit():
   136  				return
   137  			}
   138  		}
   139  	}()
   140  
   141  	// hacky way to abort the connection before write
   142  	if err := c.conn.Close(); err != nil {
   143  		t.Error(err)
   144  	}
   145  	s.Close()
   146  
   147  	// results in WS write error
   148  	// provide timeout to avoid blocking
   149  	ctx, cancel := context.WithTimeout(context.Background(), wsCallTimeout)
   150  	defer cancel()
   151  	if err := c.Call(ctx, "a", make(map[string]interface{})); err != nil {
   152  		t.Error(err)
   153  	}
   154  
   155  	// expect to reconnect almost immediately
   156  	time.Sleep(10 * time.Millisecond)
   157  
   158  	done := make(chan struct{})
   159  	go func() {
   160  		// client should block on this
   161  		call(t, "b", c)
   162  		close(done)
   163  	}()
   164  
   165  	// test that client blocks on the second send
   166  	select {
   167  	case <-done:
   168  		t.Fatal("client should block on calling 'b' during reconnect")
   169  	case <-time.After(5 * time.Second):
   170  		t.Log("All good")
   171  	}
   172  }
   173  
   174  func TestNotBlockingOnStop(t *testing.T) {
   175  	timeout := 2 * time.Second
   176  	s := httptest.NewServer(&myHandler{})
   177  	c := startClient(t, s.Listener.Addr())
   178  	c.Call(context.Background(), "a", make(map[string]interface{}))
   179  	// Let the readRoutine get around to blocking
   180  	time.Sleep(time.Second)
   181  	passCh := make(chan struct{})
   182  	go func() {
   183  		// Unless we have a non-blocking write to ResponsesCh from readRoutine
   184  		// this blocks forever ont the waitgroup
   185  		c.Stop()
   186  		passCh <- struct{}{}
   187  	}()
   188  	select {
   189  	case <-passCh:
   190  		// Pass
   191  	case <-time.After(timeout):
   192  		t.Fatalf("WSClient did failed to stop within %v seconds - is one of the read/write routines blocking?",
   193  			timeout.Seconds())
   194  	}
   195  }
   196  
   197  func startClient(t *testing.T, addr net.Addr) *WSClient {
   198  	c := NewWSClient(addr.String(), "/websocket")
   199  	err := c.Start()
   200  	require.Nil(t, err)
   201  	c.SetLogger(log.TestingLogger())
   202  	return c
   203  }
   204  
   205  func call(t *testing.T, method string, c *WSClient) {
   206  	err := c.Call(context.Background(), method, make(map[string]interface{}))
   207  	require.NoError(t, err)
   208  }
   209  
   210  func callWgDoneOnResult(t *testing.T, c *WSClient, wg *sync.WaitGroup) {
   211  	for {
   212  		select {
   213  		case resp := <-c.ResponsesCh:
   214  			if resp.Error != nil {
   215  				t.Fatalf("unexpected error: %v", resp.Error)
   216  			}
   217  			if resp.Result != nil {
   218  				wg.Done()
   219  			}
   220  		case <-c.Quit():
   221  			return
   222  		}
   223  	}
   224  }