github.com/evdatsion/aphelion-dpos-bft@v0.32.1/rpc/lib/client/ws_client.go (about)

     1  package rpcclient
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/gorilla/websocket"
    13  	"github.com/pkg/errors"
    14  	metrics "github.com/rcrowley/go-metrics"
    15  
    16  	amino "github.com/evdatsion/go-amino"
    17  	cmn "github.com/evdatsion/aphelion-dpos-bft/libs/common"
    18  	types "github.com/evdatsion/aphelion-dpos-bft/rpc/lib/types"
    19  )
    20  
    21  const (
    22  	defaultMaxReconnectAttempts = 25
    23  	defaultWriteWait            = 0
    24  	defaultReadWait             = 0
    25  	defaultPingPeriod           = 0
    26  )
    27  
// WSClient is a WebSocket client for a JSON-RPC server. Requests are sent with
// Send/Call/CallWithArrayParams and responses are delivered on ResponsesCh.
// The methods of WSClient are safe for use by multiple goroutines.
//
// Internally a read routine and a write routine own the connection; when
// either of them fails it reports on reconnectAfter and a reconnect routine
// redials with exponential backoff.
//
// NOTE(review): SetCodec writes cdc without holding mtx, so it presumably
// should only be called before the client is used concurrently — confirm.
type WSClient struct {
	cmn.BaseService

	conn *websocket.Conn
	cdc  *amino.Codec

	Address  string // IP:PORT or /path/to/socket
	Endpoint string // /websocket/url/endpoint
	Dialer   func(string, string) (net.Conn, error)

	// Time between sending a ping and receiving a pong. See
	// https://godoc.org/github.com/rcrowley/go-metrics#Timer.
	PingPongLatencyTimer metrics.Timer

	// Single user facing channel to read RPCResponses from, closed only when
	// the client is being stopped. It is created in OnStart, so it is nil
	// until the client has been started.
	ResponsesCh chan types.RPCResponse

	// Callback, which will be called (in its own goroutine) each time after a
	// successful reconnect.
	onReconnect func()

	// internal channels
	send            chan types.RPCRequest // user requests
	backlog         chan types.RPCRequest // stores a single user request received during a conn failure
	reconnectAfter  chan error            // reconnect requests
	readRoutineQuit chan struct{}         // a way for readRoutine to close writeRoutine

	// wg tracks the currently running readRoutine and writeRoutine.
	wg sync.WaitGroup

	// mtx guards sentLastPingAt and reconnecting.
	mtx            sync.RWMutex
	sentLastPingAt time.Time
	reconnecting   bool

	// Maximum reconnect attempts (0 or greater; default: 25).
	// NOTE: reconnect gives up only once the failed-attempt count exceeds this
	// value, so it actually dials up to maxReconnectAttempts+1 times.
	maxReconnectAttempts int

	// Time allowed to write a message to the server. 0 means block until operation succeeds.
	writeWait time.Duration

	// Time allowed to read the next message from the server. 0 means block until operation succeeds.
	readWait time.Duration

	// Send pings to server with this period. Must be less than readWait. If 0, no pings will be sent.
	pingPeriod time.Duration

	// URL scheme used when dialing: "wss" if explicitly requested, otherwise "ws".
	protocol string
}
    77  
    78  // NewWSClient returns a new client. See the commentary on the func(*WSClient)
    79  // functions for a detailed description of how to configure ping period and
    80  // pong wait time. The endpoint argument must begin with a `/`.
    81  func NewWSClient(remoteAddr, endpoint string, options ...func(*WSClient)) *WSClient {
    82  	protocol, addr, dialer := makeHTTPDialer(remoteAddr)
    83  	// default to ws protocol, unless wss is explicitly specified
    84  	if protocol != "wss" {
    85  		protocol = "ws"
    86  	}
    87  
    88  	c := &WSClient{
    89  		cdc:                  amino.NewCodec(),
    90  		Address:              addr,
    91  		Dialer:               dialer,
    92  		Endpoint:             endpoint,
    93  		PingPongLatencyTimer: metrics.NewTimer(),
    94  
    95  		maxReconnectAttempts: defaultMaxReconnectAttempts,
    96  		readWait:             defaultReadWait,
    97  		writeWait:            defaultWriteWait,
    98  		pingPeriod:           defaultPingPeriod,
    99  		protocol:             protocol,
   100  	}
   101  	c.BaseService = *cmn.NewBaseService(nil, "WSClient", c)
   102  	for _, option := range options {
   103  		option(c)
   104  	}
   105  	return c
   106  }
   107  
   108  // MaxReconnectAttempts sets the maximum number of reconnect attempts before returning an error.
   109  // It should only be used in the constructor and is not Goroutine-safe.
   110  func MaxReconnectAttempts(max int) func(*WSClient) {
   111  	return func(c *WSClient) {
   112  		c.maxReconnectAttempts = max
   113  	}
   114  }
   115  
   116  // ReadWait sets the amount of time to wait before a websocket read times out.
   117  // It should only be used in the constructor and is not Goroutine-safe.
   118  func ReadWait(readWait time.Duration) func(*WSClient) {
   119  	return func(c *WSClient) {
   120  		c.readWait = readWait
   121  	}
   122  }
   123  
   124  // WriteWait sets the amount of time to wait before a websocket write times out.
   125  // It should only be used in the constructor and is not Goroutine-safe.
   126  func WriteWait(writeWait time.Duration) func(*WSClient) {
   127  	return func(c *WSClient) {
   128  		c.writeWait = writeWait
   129  	}
   130  }
   131  
   132  // PingPeriod sets the duration for sending websocket pings.
   133  // It should only be used in the constructor - not Goroutine-safe.
   134  func PingPeriod(pingPeriod time.Duration) func(*WSClient) {
   135  	return func(c *WSClient) {
   136  		c.pingPeriod = pingPeriod
   137  	}
   138  }
   139  
   140  // OnReconnect sets the callback, which will be called every time after
   141  // successful reconnect.
   142  func OnReconnect(cb func()) func(*WSClient) {
   143  	return func(c *WSClient) {
   144  		c.onReconnect = cb
   145  	}
   146  }
   147  
   148  // String returns WS client full address.
   149  func (c *WSClient) String() string {
   150  	return fmt.Sprintf("%s (%s)", c.Address, c.Endpoint)
   151  }
   152  
   153  // OnStart implements cmn.Service by dialing a server and creating read and
   154  // write routines.
   155  func (c *WSClient) OnStart() error {
   156  	err := c.dial()
   157  	if err != nil {
   158  		return err
   159  	}
   160  
   161  	c.ResponsesCh = make(chan types.RPCResponse)
   162  
   163  	c.send = make(chan types.RPCRequest)
   164  	// 1 additional error may come from the read/write
   165  	// goroutine depending on which failed first.
   166  	c.reconnectAfter = make(chan error, 1)
   167  	// capacity for 1 request. a user won't be able to send more because the send
   168  	// channel is unbuffered.
   169  	c.backlog = make(chan types.RPCRequest, 1)
   170  
   171  	c.startReadWriteRoutines()
   172  	go c.reconnectRoutine()
   173  
   174  	return nil
   175  }
   176  
   177  // Stop overrides cmn.Service#Stop. There is no other way to wait until Quit
   178  // channel is closed.
   179  func (c *WSClient) Stop() error {
   180  	if err := c.BaseService.Stop(); err != nil {
   181  		return err
   182  	}
   183  	// only close user-facing channels when we can't write to them
   184  	c.wg.Wait()
   185  	close(c.ResponsesCh)
   186  
   187  	return nil
   188  }
   189  
   190  // IsReconnecting returns true if the client is reconnecting right now.
   191  func (c *WSClient) IsReconnecting() bool {
   192  	c.mtx.RLock()
   193  	defer c.mtx.RUnlock()
   194  	return c.reconnecting
   195  }
   196  
   197  // IsActive returns true if the client is running and not reconnecting.
   198  func (c *WSClient) IsActive() bool {
   199  	return c.IsRunning() && !c.IsReconnecting()
   200  }
   201  
   202  // Send the given RPC request to the server. Results will be available on
   203  // ResponsesCh, errors, if any, on ErrorsCh. Will block until send succeeds or
   204  // ctx.Done is closed.
   205  func (c *WSClient) Send(ctx context.Context, request types.RPCRequest) error {
   206  	select {
   207  	case c.send <- request:
   208  		c.Logger.Info("sent a request", "req", request)
   209  		return nil
   210  	case <-ctx.Done():
   211  		return ctx.Err()
   212  	}
   213  }
   214  
   215  // Call the given method. See Send description.
   216  func (c *WSClient) Call(ctx context.Context, method string, params map[string]interface{}) error {
   217  	request, err := types.MapToRequest(c.cdc, types.JSONRPCStringID("ws-client"), method, params)
   218  	if err != nil {
   219  		return err
   220  	}
   221  	return c.Send(ctx, request)
   222  }
   223  
   224  // CallWithArrayParams the given method with params in a form of array. See
   225  // Send description.
   226  func (c *WSClient) CallWithArrayParams(ctx context.Context, method string, params []interface{}) error {
   227  	request, err := types.ArrayToRequest(c.cdc, types.JSONRPCStringID("ws-client"), method, params)
   228  	if err != nil {
   229  		return err
   230  	}
   231  	return c.Send(ctx, request)
   232  }
   233  
   234  func (c *WSClient) Codec() *amino.Codec {
   235  	return c.cdc
   236  }
   237  
   238  func (c *WSClient) SetCodec(cdc *amino.Codec) {
   239  	c.cdc = cdc
   240  }
   241  
   242  ///////////////////////////////////////////////////////////////////////////////
   243  // Private methods
   244  
   245  func (c *WSClient) dial() error {
   246  	dialer := &websocket.Dialer{
   247  		NetDial: c.Dialer,
   248  		Proxy:   http.ProxyFromEnvironment,
   249  	}
   250  	rHeader := http.Header{}
   251  	conn, _, err := dialer.Dial(c.protocol+"://"+c.Address+c.Endpoint, rHeader)
   252  	if err != nil {
   253  		return err
   254  	}
   255  	c.conn = conn
   256  	return nil
   257  }
   258  
   259  // reconnect tries to redial up to maxReconnectAttempts with exponential
   260  // backoff.
   261  func (c *WSClient) reconnect() error {
   262  	attempt := 0
   263  
   264  	c.mtx.Lock()
   265  	c.reconnecting = true
   266  	c.mtx.Unlock()
   267  	defer func() {
   268  		c.mtx.Lock()
   269  		c.reconnecting = false
   270  		c.mtx.Unlock()
   271  	}()
   272  
   273  	for {
   274  		jitterSeconds := time.Duration(cmn.RandFloat64() * float64(time.Second)) // 1s == (1e9 ns)
   275  		backoffDuration := jitterSeconds + ((1 << uint(attempt)) * time.Second)
   276  
   277  		c.Logger.Info("reconnecting", "attempt", attempt+1, "backoff_duration", backoffDuration)
   278  		time.Sleep(backoffDuration)
   279  
   280  		err := c.dial()
   281  		if err != nil {
   282  			c.Logger.Error("failed to redial", "err", err)
   283  		} else {
   284  			c.Logger.Info("reconnected")
   285  			if c.onReconnect != nil {
   286  				go c.onReconnect()
   287  			}
   288  			return nil
   289  		}
   290  
   291  		attempt++
   292  
   293  		if attempt > c.maxReconnectAttempts {
   294  			return errors.Wrap(err, "reached maximum reconnect attempts")
   295  		}
   296  	}
   297  }
   298  
   299  func (c *WSClient) startReadWriteRoutines() {
   300  	c.wg.Add(2)
   301  	c.readRoutineQuit = make(chan struct{})
   302  	go c.readRoutine()
   303  	go c.writeRoutine()
   304  }
   305  
   306  func (c *WSClient) processBacklog() error {
   307  	select {
   308  	case request := <-c.backlog:
   309  		if c.writeWait > 0 {
   310  			if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeWait)); err != nil {
   311  				c.Logger.Error("failed to set write deadline", "err", err)
   312  			}
   313  		}
   314  		if err := c.conn.WriteJSON(request); err != nil {
   315  			c.Logger.Error("failed to resend request", "err", err)
   316  			c.reconnectAfter <- err
   317  			// requeue request
   318  			c.backlog <- request
   319  			return err
   320  		}
   321  		c.Logger.Info("resend a request", "req", request)
   322  	default:
   323  	}
   324  	return nil
   325  }
   326  
   327  func (c *WSClient) reconnectRoutine() {
   328  	for {
   329  		select {
   330  		case originalError := <-c.reconnectAfter:
   331  			// wait until writeRoutine and readRoutine finish
   332  			c.wg.Wait()
   333  			if err := c.reconnect(); err != nil {
   334  				c.Logger.Error("failed to reconnect", "err", err, "original_err", originalError)
   335  				c.Stop()
   336  				return
   337  			}
   338  			// drain reconnectAfter
   339  		LOOP:
   340  			for {
   341  				select {
   342  				case <-c.reconnectAfter:
   343  				default:
   344  					break LOOP
   345  				}
   346  			}
   347  			err := c.processBacklog()
   348  			if err == nil {
   349  				c.startReadWriteRoutines()
   350  			}
   351  
   352  		case <-c.Quit():
   353  			return
   354  		}
   355  	}
   356  }
   357  
// The client ensures that there is at most one writer to a connection by
// executing all writes from this goroutine.
//
// writeRoutine forwards requests from c.send to the server. When pingPeriod
// > 0 it also sends a ping every pingPeriod and records the send time in
// sentLastPingAt for latency stats. On a write failure it requests a reconnect
// via reconnectAfter and exits; a request that failed to be written is parked
// in backlog so it can be resent after reconnecting. It also exits when
// readRoutine closes readRoutineQuit, or on Quit after sending a
// normal-closure close frame. On exit it stops the ticker, closes the
// connection and marks itself done in wg.
func (c *WSClient) writeRoutine() {
	var ticker *time.Ticker
	if c.pingPeriod > 0 {
		// ticker with a predefined period
		ticker = time.NewTicker(c.pingPeriod)
	} else {
		// ticker that never fires: nothing can ever send on this
		// receive-only channel, so the ping case below is never selected
		ticker = &time.Ticker{C: make(<-chan time.Time)}
	}

	defer func() {
		ticker.Stop()
		if err := c.conn.Close(); err != nil {
			// ignore error; it will trigger in tests
			// likely because it's closing an already closed connection
		}
		c.wg.Done()
	}()

	for {
		select {
		case request := <-c.send:
			if c.writeWait > 0 {
				if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeWait)); err != nil {
					c.Logger.Error("failed to set write deadline", "err", err)
				}
			}
			if err := c.conn.WriteJSON(request); err != nil {
				c.Logger.Error("failed to send request", "err", err)
				c.reconnectAfter <- err
				// add request to the backlog, so we don't lose it
				c.backlog <- request
				return
			}
		case <-ticker.C:
			if c.writeWait > 0 {
				if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeWait)); err != nil {
					c.Logger.Error("failed to set write deadline", "err", err)
				}
			}
			if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
				c.Logger.Error("failed to write ping", "err", err)
				c.reconnectAfter <- err
				return
			}
			// read by readRoutine's pong handler to compute ping/pong latency
			c.mtx.Lock()
			c.sentLastPingAt = time.Now()
			c.mtx.Unlock()
			c.Logger.Debug("sent ping")
		case <-c.readRoutineQuit:
			// readRoutine failed; it has already requested the reconnect
			return
		case <-c.Quit():
			// client is stopping: tell the server we are closing normally
			if err := c.conn.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")); err != nil {
				c.Logger.Error("failed to write message", "err", err)
			}
			return
		}
	}
}
   419  
// The client ensures that there is at most one reader to a connection by
// executing all reads from this goroutine.
//
// readRoutine records ping/pong latency in PingPongLatencyTimer, refreshes the
// read deadline before every read when readWait > 0, and decodes each message
// as a types.RPCResponse that it delivers on ResponsesCh. Messages that fail to
// decode are logged and skipped. On exit it closes the connection and marks
// itself done in wg.
//
// NOTE(review): only an unexpected close error (per gorilla/websocket's
// IsUnexpectedCloseError: a *CloseError whose code is not CloseNormalClosure)
// closes readRoutineQuit and requests a reconnect. Every other read error,
// including non-close errors such as read-deadline timeouts or a connection
// already closed by writeRoutine, makes this routine return silently. Noticing
// the broken connection is then left to writeRoutine's next write or ping.
// Confirm this is intended when readWait is used without pings.
func (c *WSClient) readRoutine() {
	defer func() {
		if err := c.conn.Close(); err != nil {
			// ignore error; it will trigger in tests
			// likely because it's closing an already closed connection
		}
		c.wg.Done()
	}()

	c.conn.SetPongHandler(func(string) error {
		// gather latency stats
		c.mtx.RLock()
		t := c.sentLastPingAt
		c.mtx.RUnlock()
		c.PingPongLatencyTimer.UpdateSince(t)

		c.Logger.Debug("got pong")
		return nil
	})

	for {
		// reset deadline for every message type (control or data)
		if c.readWait > 0 {
			if err := c.conn.SetReadDeadline(time.Now().Add(c.readWait)); err != nil {
				c.Logger.Error("failed to set read deadline", "err", err)
			}
		}
		_, data, err := c.conn.ReadMessage()
		if err != nil {
			// normal closure or a non-close error: exit quietly (see NOTE above)
			if !websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
				return
			}

			c.Logger.Error("failed to read response", "err", err)
			// stop writeRoutine first, then ask for a reconnect
			close(c.readRoutineQuit)
			c.reconnectAfter <- err
			return
		}

		var response types.RPCResponse
		err = json.Unmarshal(data, &response)
		if err != nil {
			c.Logger.Error("failed to parse response", "err", err, "data", string(data))
			continue
		}
		c.Logger.Info("got response", "resp", response.Result)
		// Combine a non-blocking read on BaseService.Quit with a non-blocking write on ResponsesCh to avoid blocking
		// c.wg.Wait() in c.Stop(). Note we rely on Quit being closed so that it sends unlimited Quit signals to stop
		// both readRoutine and writeRoutine
		select {
		case <-c.Quit():
		case c.ResponsesCh <- response:
		}
	}
}
   477  
   478  ///////////////////////////////////////////////////////////////////////////////
   479  // Predefined methods
   480  
   481  // Subscribe to a query. Note the server must have a "subscribe" route
   482  // defined.
   483  func (c *WSClient) Subscribe(ctx context.Context, query string) error {
   484  	params := map[string]interface{}{"query": query}
   485  	return c.Call(ctx, "subscribe", params)
   486  }
   487  
   488  // Unsubscribe from a query. Note the server must have a "unsubscribe" route
   489  // defined.
   490  func (c *WSClient) Unsubscribe(ctx context.Context, query string) error {
   491  	params := map[string]interface{}{"query": query}
   492  	return c.Call(ctx, "unsubscribe", params)
   493  }
   494  
   495  // UnsubscribeAll from all. Note the server must have a "unsubscribe_all" route
   496  // defined.
   497  func (c *WSClient) UnsubscribeAll(ctx context.Context) error {
   498  	params := map[string]interface{}{}
   499  	return c.Call(ctx, "unsubscribe_all", params)
   500  }