github.com/arcology-network/consensus-engine@v1.9.0/rpc/jsonrpc/client/ws_client.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/gorilla/websocket"
    13  	metrics "github.com/rcrowley/go-metrics"
    14  
    15  	tmrand "github.com/arcology-network/consensus-engine/libs/rand"
    16  	"github.com/arcology-network/consensus-engine/libs/service"
    17  	tmsync "github.com/arcology-network/consensus-engine/libs/sync"
    18  	types "github.com/arcology-network/consensus-engine/rpc/jsonrpc/types"
    19  )
    20  
    21  const (
    22  	defaultMaxReconnectAttempts = 25
    23  	defaultWriteWait            = 0
    24  	defaultReadWait             = 0
    25  	defaultPingPeriod           = 0
    26  )
    27  
    28  // WSClient is a JSON-RPC client, which uses WebSocket for communication with
    29  // the remote server.
    30  //
    31  // WSClient is safe for concurrent use by multiple goroutines.
    32  type WSClient struct { // nolint: maligned
    33  	conn *websocket.Conn
    34  
    35  	Address  string // IP:PORT or /path/to/socket
    36  	Endpoint string // /websocket/url/endpoint
    37  	Dialer   func(string, string) (net.Conn, error)
    38  
    39  	// Single user facing channel to read RPCResponses from, closed only when the
    40  	// client is being stopped.
    41  	ResponsesCh chan types.RPCResponse
    42  
    43  	// Callback, which will be called each time after successful reconnect.
    44  	onReconnect func()
    45  
    46  	// internal channels
    47  	send            chan types.RPCRequest // user requests
    48  	backlog         chan types.RPCRequest // stores a single user request received during a conn failure
    49  	reconnectAfter  chan error            // reconnect requests
    50  	readRoutineQuit chan struct{}         // a way for readRoutine to close writeRoutine
    51  
    52  	// Maximum reconnect attempts (0 or greater; default: 25).
    53  	maxReconnectAttempts int
    54  
    55  	// Support both ws and wss protocols
    56  	protocol string
    57  
    58  	wg sync.WaitGroup
    59  
    60  	mtx            tmsync.RWMutex
    61  	sentLastPingAt time.Time
    62  	reconnecting   bool
    63  	nextReqID      int
    64  	// sentIDs        map[types.JSONRPCIntID]bool // IDs of the requests currently in flight
    65  
    66  	// Time allowed to write a message to the server. 0 means block until operation succeeds.
    67  	writeWait time.Duration
    68  
    69  	// Time allowed to read the next message from the server. 0 means block until operation succeeds.
    70  	readWait time.Duration
    71  
    72  	// Send pings to server with this period. Must be less than readWait. If 0, no pings will be sent.
    73  	pingPeriod time.Duration
    74  
    75  	service.BaseService
    76  
    77  	// Time between sending a ping and receiving a pong. See
    78  	// https://godoc.org/github.com/rcrowley/go-metrics#Timer.
    79  	PingPongLatencyTimer metrics.Timer
    80  }
    81  
    82  // NewWS returns a new client. See the commentary on the func(*WSClient)
    83  // functions for a detailed description of how to configure ping period and
    84  // pong wait time. The endpoint argument must begin with a `/`.
    85  // An error is returned on invalid remote. The function panics when remote is nil.
    86  func NewWS(remoteAddr, endpoint string, options ...func(*WSClient)) (*WSClient, error) {
    87  	parsedURL, err := newParsedURL(remoteAddr)
    88  	if err != nil {
    89  		return nil, err
    90  	}
    91  	// default to ws protocol, unless wss is explicitly specified
    92  	if parsedURL.Scheme != protoWSS {
    93  		parsedURL.Scheme = protoWS
    94  	}
    95  
    96  	dialFn, err := makeHTTPDialer(remoteAddr)
    97  	if err != nil {
    98  		return nil, err
    99  	}
   100  
   101  	c := &WSClient{
   102  		Address:              parsedURL.GetTrimmedHostWithPath(),
   103  		Dialer:               dialFn,
   104  		Endpoint:             endpoint,
   105  		PingPongLatencyTimer: metrics.NewTimer(),
   106  
   107  		maxReconnectAttempts: defaultMaxReconnectAttempts,
   108  		readWait:             defaultReadWait,
   109  		writeWait:            defaultWriteWait,
   110  		pingPeriod:           defaultPingPeriod,
   111  		protocol:             parsedURL.Scheme,
   112  
   113  		// sentIDs: make(map[types.JSONRPCIntID]bool),
   114  	}
   115  	c.BaseService = *service.NewBaseService(nil, "WSClient", c)
   116  	for _, option := range options {
   117  		option(c)
   118  	}
   119  	return c, nil
   120  }
   121  
   122  // MaxReconnectAttempts sets the maximum number of reconnect attempts before returning an error.
   123  // It should only be used in the constructor and is not Goroutine-safe.
   124  func MaxReconnectAttempts(max int) func(*WSClient) {
   125  	return func(c *WSClient) {
   126  		c.maxReconnectAttempts = max
   127  	}
   128  }
   129  
   130  // ReadWait sets the amount of time to wait before a websocket read times out.
   131  // It should only be used in the constructor and is not Goroutine-safe.
   132  func ReadWait(readWait time.Duration) func(*WSClient) {
   133  	return func(c *WSClient) {
   134  		c.readWait = readWait
   135  	}
   136  }
   137  
   138  // WriteWait sets the amount of time to wait before a websocket write times out.
   139  // It should only be used in the constructor and is not Goroutine-safe.
   140  func WriteWait(writeWait time.Duration) func(*WSClient) {
   141  	return func(c *WSClient) {
   142  		c.writeWait = writeWait
   143  	}
   144  }
   145  
   146  // PingPeriod sets the duration for sending websocket pings.
   147  // It should only be used in the constructor - not Goroutine-safe.
   148  func PingPeriod(pingPeriod time.Duration) func(*WSClient) {
   149  	return func(c *WSClient) {
   150  		c.pingPeriod = pingPeriod
   151  	}
   152  }
   153  
   154  // OnReconnect sets the callback, which will be called every time after
   155  // successful reconnect.
   156  func OnReconnect(cb func()) func(*WSClient) {
   157  	return func(c *WSClient) {
   158  		c.onReconnect = cb
   159  	}
   160  }
   161  
   162  // String returns WS client full address.
   163  func (c *WSClient) String() string {
   164  	return fmt.Sprintf("WSClient{%s (%s)}", c.Address, c.Endpoint)
   165  }
   166  
   167  // OnStart implements service.Service by dialing a server and creating read and
   168  // write routines.
   169  func (c *WSClient) OnStart() error {
   170  	err := c.dial()
   171  	if err != nil {
   172  		return err
   173  	}
   174  
   175  	c.ResponsesCh = make(chan types.RPCResponse)
   176  
   177  	c.send = make(chan types.RPCRequest)
   178  	// 1 additional error may come from the read/write
   179  	// goroutine depending on which failed first.
   180  	c.reconnectAfter = make(chan error, 1)
   181  	// capacity for 1 request. a user won't be able to send more because the send
   182  	// channel is unbuffered.
   183  	c.backlog = make(chan types.RPCRequest, 1)
   184  
   185  	c.startReadWriteRoutines()
   186  	go c.reconnectRoutine()
   187  
   188  	return nil
   189  }
   190  
   191  // Stop overrides service.Service#Stop. There is no other way to wait until Quit
   192  // channel is closed.
   193  func (c *WSClient) Stop() error {
   194  	if err := c.BaseService.Stop(); err != nil {
   195  		return err
   196  	}
   197  	// only close user-facing channels when we can't write to them
   198  	c.wg.Wait()
   199  	close(c.ResponsesCh)
   200  
   201  	return nil
   202  }
   203  
   204  // IsReconnecting returns true if the client is reconnecting right now.
   205  func (c *WSClient) IsReconnecting() bool {
   206  	c.mtx.RLock()
   207  	defer c.mtx.RUnlock()
   208  	return c.reconnecting
   209  }
   210  
   211  // IsActive returns true if the client is running and not reconnecting.
   212  func (c *WSClient) IsActive() bool {
   213  	return c.IsRunning() && !c.IsReconnecting()
   214  }
   215  
   216  // Send the given RPC request to the server. Results will be available on
   217  // ResponsesCh, errors, if any, on ErrorsCh. Will block until send succeeds or
   218  // ctx.Done is closed.
   219  func (c *WSClient) Send(ctx context.Context, request types.RPCRequest) error {
   220  	select {
   221  	case c.send <- request:
   222  		c.Logger.Info("sent a request", "req", request)
   223  		// c.mtx.Lock()
   224  		// c.sentIDs[request.ID.(types.JSONRPCIntID)] = true
   225  		// c.mtx.Unlock()
   226  		return nil
   227  	case <-ctx.Done():
   228  		return ctx.Err()
   229  	}
   230  }
   231  
   232  // Call enqueues a call request onto the Send queue. Requests are JSON encoded.
   233  func (c *WSClient) Call(ctx context.Context, method string, params map[string]interface{}) error {
   234  	request, err := types.MapToRequest(c.nextRequestID(), method, params)
   235  	if err != nil {
   236  		return err
   237  	}
   238  	return c.Send(ctx, request)
   239  }
   240  
   241  // CallWithArrayParams enqueues a call request onto the Send queue. Params are
   242  // in a form of array (e.g. []interface{}{"abcd"}). Requests are JSON encoded.
   243  func (c *WSClient) CallWithArrayParams(ctx context.Context, method string, params []interface{}) error {
   244  	request, err := types.ArrayToRequest(c.nextRequestID(), method, params)
   245  	if err != nil {
   246  		return err
   247  	}
   248  	return c.Send(ctx, request)
   249  }
   250  
   251  // Private methods
   252  
   253  func (c *WSClient) nextRequestID() types.JSONRPCIntID {
   254  	c.mtx.Lock()
   255  	id := c.nextReqID
   256  	c.nextReqID++
   257  	c.mtx.Unlock()
   258  	return types.JSONRPCIntID(id)
   259  }
   260  
   261  func (c *WSClient) dial() error {
   262  	dialer := &websocket.Dialer{
   263  		NetDial: c.Dialer,
   264  		Proxy:   http.ProxyFromEnvironment,
   265  	}
   266  	rHeader := http.Header{}
   267  	conn, _, err := dialer.Dial(c.protocol+"://"+c.Address+c.Endpoint, rHeader) // nolint:bodyclose
   268  	if err != nil {
   269  		return err
   270  	}
   271  	c.conn = conn
   272  	return nil
   273  }
   274  
   275  // reconnect tries to redial up to maxReconnectAttempts with exponential
   276  // backoff.
   277  func (c *WSClient) reconnect() error {
   278  	attempt := 0
   279  
   280  	c.mtx.Lock()
   281  	c.reconnecting = true
   282  	c.mtx.Unlock()
   283  	defer func() {
   284  		c.mtx.Lock()
   285  		c.reconnecting = false
   286  		c.mtx.Unlock()
   287  	}()
   288  
   289  	for {
   290  		jitter := time.Duration(tmrand.Float64() * float64(time.Second)) // 1s == (1e9 ns)
   291  		backoffDuration := jitter + ((1 << uint(attempt)) * time.Second)
   292  
   293  		c.Logger.Info("reconnecting", "attempt", attempt+1, "backoff_duration", backoffDuration)
   294  		time.Sleep(backoffDuration)
   295  
   296  		err := c.dial()
   297  		if err != nil {
   298  			c.Logger.Error("failed to redial", "err", err)
   299  		} else {
   300  			c.Logger.Info("reconnected")
   301  			if c.onReconnect != nil {
   302  				go c.onReconnect()
   303  			}
   304  			return nil
   305  		}
   306  
   307  		attempt++
   308  
   309  		if attempt > c.maxReconnectAttempts {
   310  			return fmt.Errorf("reached maximum reconnect attempts: %w", err)
   311  		}
   312  	}
   313  }
   314  
   315  func (c *WSClient) startReadWriteRoutines() {
   316  	c.wg.Add(2)
   317  	c.readRoutineQuit = make(chan struct{})
   318  	go c.readRoutine()
   319  	go c.writeRoutine()
   320  }
   321  
   322  func (c *WSClient) processBacklog() error {
   323  	select {
   324  	case request := <-c.backlog:
   325  		if c.writeWait > 0 {
   326  			if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeWait)); err != nil {
   327  				c.Logger.Error("failed to set write deadline", "err", err)
   328  			}
   329  		}
   330  		if err := c.conn.WriteJSON(request); err != nil {
   331  			c.Logger.Error("failed to resend request", "err", err)
   332  			c.reconnectAfter <- err
   333  			// requeue request
   334  			c.backlog <- request
   335  			return err
   336  		}
   337  		c.Logger.Info("resend a request", "req", request)
   338  	default:
   339  	}
   340  	return nil
   341  }
   342  
   343  func (c *WSClient) reconnectRoutine() {
   344  	for {
   345  		select {
   346  		case originalError := <-c.reconnectAfter:
   347  			// wait until writeRoutine and readRoutine finish
   348  			c.wg.Wait()
   349  			if err := c.reconnect(); err != nil {
   350  				c.Logger.Error("failed to reconnect", "err", err, "original_err", originalError)
   351  				if err = c.Stop(); err != nil {
   352  					c.Logger.Error("failed to stop conn", "error", err)
   353  				}
   354  
   355  				return
   356  			}
   357  			// drain reconnectAfter
   358  		LOOP:
   359  			for {
   360  				select {
   361  				case <-c.reconnectAfter:
   362  				default:
   363  					break LOOP
   364  				}
   365  			}
   366  			err := c.processBacklog()
   367  			if err == nil {
   368  				c.startReadWriteRoutines()
   369  			}
   370  
   371  		case <-c.Quit():
   372  			return
   373  		}
   374  	}
   375  }
   376  
   377  // The client ensures that there is at most one writer to a connection by
   378  // executing all writes from this goroutine.
   379  func (c *WSClient) writeRoutine() {
   380  	var ticker *time.Ticker
   381  	if c.pingPeriod > 0 {
   382  		// ticker with a predefined period
   383  		ticker = time.NewTicker(c.pingPeriod)
   384  	} else {
   385  		// ticker that never fires
   386  		ticker = &time.Ticker{C: make(<-chan time.Time)}
   387  	}
   388  
   389  	defer func() {
   390  		ticker.Stop()
   391  		c.conn.Close()
   392  		// err != nil {
   393  		// ignore error; it will trigger in tests
   394  		// likely because it's closing an already closed connection
   395  		// }
   396  		c.wg.Done()
   397  	}()
   398  
   399  	for {
   400  		select {
   401  		case request := <-c.send:
   402  			if c.writeWait > 0 {
   403  				if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeWait)); err != nil {
   404  					c.Logger.Error("failed to set write deadline", "err", err)
   405  				}
   406  			}
   407  			if err := c.conn.WriteJSON(request); err != nil {
   408  				c.Logger.Error("failed to send request", "err", err)
   409  				c.reconnectAfter <- err
   410  				// add request to the backlog, so we don't lose it
   411  				c.backlog <- request
   412  				return
   413  			}
   414  		case <-ticker.C:
   415  			if c.writeWait > 0 {
   416  				if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeWait)); err != nil {
   417  					c.Logger.Error("failed to set write deadline", "err", err)
   418  				}
   419  			}
   420  			if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
   421  				c.Logger.Error("failed to write ping", "err", err)
   422  				c.reconnectAfter <- err
   423  				return
   424  			}
   425  			c.mtx.Lock()
   426  			c.sentLastPingAt = time.Now()
   427  			c.mtx.Unlock()
   428  			c.Logger.Debug("sent ping")
   429  		case <-c.readRoutineQuit:
   430  			return
   431  		case <-c.Quit():
   432  			if err := c.conn.WriteMessage(
   433  				websocket.CloseMessage,
   434  				websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
   435  			); err != nil {
   436  				c.Logger.Error("failed to write message", "err", err)
   437  			}
   438  			return
   439  		}
   440  	}
   441  }
   442  
   443  // The client ensures that there is at most one reader to a connection by
   444  // executing all reads from this goroutine.
   445  func (c *WSClient) readRoutine() {
   446  	defer func() {
   447  		c.conn.Close()
   448  		// err != nil {
   449  		// ignore error; it will trigger in tests
   450  		// likely because it's closing an already closed connection
   451  		// }
   452  		c.wg.Done()
   453  	}()
   454  
   455  	c.conn.SetPongHandler(func(string) error {
   456  		// gather latency stats
   457  		c.mtx.RLock()
   458  		t := c.sentLastPingAt
   459  		c.mtx.RUnlock()
   460  		c.PingPongLatencyTimer.UpdateSince(t)
   461  
   462  		c.Logger.Debug("got pong")
   463  		return nil
   464  	})
   465  
   466  	for {
   467  		// reset deadline for every message type (control or data)
   468  		if c.readWait > 0 {
   469  			if err := c.conn.SetReadDeadline(time.Now().Add(c.readWait)); err != nil {
   470  				c.Logger.Error("failed to set read deadline", "err", err)
   471  			}
   472  		}
   473  		_, data, err := c.conn.ReadMessage()
   474  		if err != nil {
   475  			if !websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
   476  				return
   477  			}
   478  
   479  			c.Logger.Error("failed to read response", "err", err)
   480  			close(c.readRoutineQuit)
   481  			c.reconnectAfter <- err
   482  			return
   483  		}
   484  
   485  		var response types.RPCResponse
   486  		err = json.Unmarshal(data, &response)
   487  		if err != nil {
   488  			c.Logger.Error("failed to parse response", "err", err, "data", string(data))
   489  			continue
   490  		}
   491  
   492  		if err = validateResponseID(response.ID); err != nil {
   493  			c.Logger.Error("error in response ID", "id", response.ID, "err", err)
   494  			continue
   495  		}
   496  
   497  		// TODO: events resulting from /subscribe do not work with ->
   498  		// because they are implemented as responses with the subscribe request's
   499  		// ID. According to the spec, they should be notifications (requests
   500  		// without IDs).
   501  		// https://github.com/arcology-network/consensus-engine/issues/2949
   502  		// c.mtx.Lock()
   503  		// if _, ok := c.sentIDs[response.ID.(types.JSONRPCIntID)]; !ok {
   504  		// 	c.Logger.Error("unsolicited response ID", "id", response.ID, "expected", c.sentIDs)
   505  		// 	c.mtx.Unlock()
   506  		// 	continue
   507  		// }
   508  		// delete(c.sentIDs, response.ID.(types.JSONRPCIntID))
   509  		// c.mtx.Unlock()
   510  		// Combine a non-blocking read on BaseService.Quit with a non-blocking write on ResponsesCh to avoid blocking
   511  		// c.wg.Wait() in c.Stop(). Note we rely on Quit being closed so that it sends unlimited Quit signals to stop
   512  		// both readRoutine and writeRoutine
   513  
   514  		c.Logger.Info("got response", "id", response.ID, "result", fmt.Sprintf("%X", response.Result))
   515  
   516  		select {
   517  		case <-c.Quit():
   518  		case c.ResponsesCh <- response:
   519  		}
   520  	}
   521  }
   522  
   523  // Predefined methods
   524  
   525  // Subscribe to a query. Note the server must have a "subscribe" route
   526  // defined.
   527  func (c *WSClient) Subscribe(ctx context.Context, query string) error {
   528  	params := map[string]interface{}{"query": query}
   529  	return c.Call(ctx, "subscribe", params)
   530  }
   531  
   532  // Unsubscribe from a query. Note the server must have a "unsubscribe" route
   533  // defined.
   534  func (c *WSClient) Unsubscribe(ctx context.Context, query string) error {
   535  	params := map[string]interface{}{"query": query}
   536  	return c.Call(ctx, "unsubscribe", params)
   537  }
   538  
   539  // UnsubscribeAll from all. Note the server must have a "unsubscribe_all" route
   540  // defined.
   541  func (c *WSClient) UnsubscribeAll(ctx context.Context) error {
   542  	params := map[string]interface{}{}
   543  	return c.Call(ctx, "unsubscribe_all", params)
   544  }