github.com/vipernet-xyz/tm@v0.34.24/rpc/jsonrpc/client/ws_client.go (about)

     1  package client
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"net"
     8  	"net/http"
     9  	"sync"
    10  	"time"
    11  
    12  	"github.com/gorilla/websocket"
    13  	metrics "github.com/rcrowley/go-metrics"
    14  
    15  	"github.com/vipernet-xyz/tm/libs/log"
    16  	tmrand "github.com/vipernet-xyz/tm/libs/rand"
    17  	"github.com/vipernet-xyz/tm/libs/service"
    18  	tmsync "github.com/vipernet-xyz/tm/libs/sync"
    19  	types "github.com/vipernet-xyz/tm/rpc/jsonrpc/types"
    20  )
    21  
    22  const (
    23  	defaultMaxReconnectAttempts = 25
    24  	defaultWriteWait            = 0
    25  	defaultReadWait             = 0
    26  	defaultPingPeriod           = 0
    27  )
    28  
    29  // WSClient is a JSON-RPC client, which uses WebSocket for communication with
    30  // the remote server.
    31  //
    32  // WSClient is safe for concurrent use by multiple goroutines.
    33  type WSClient struct { //nolint: maligned
    34  	conn *websocket.Conn
    35  
    36  	Address  string // IP:PORT or /path/to/socket
    37  	Endpoint string // /websocket/url/endpoint
    38  	Dialer   func(string, string) (net.Conn, error)
    39  
    40  	// Single user facing channel to read RPCResponses from, closed only when the
    41  	// client is being stopped.
    42  	ResponsesCh chan types.RPCResponse
    43  
    44  	// Callback, which will be called each time after successful reconnect.
    45  	onReconnect func()
    46  
    47  	// internal channels
    48  	send            chan types.RPCRequest // user requests
    49  	backlog         chan types.RPCRequest // stores a single user request received during a conn failure
    50  	reconnectAfter  chan error            // reconnect requests
    51  	readRoutineQuit chan struct{}         // a way for readRoutine to close writeRoutine
    52  
    53  	// Maximum reconnect attempts (0 or greater; default: 25).
    54  	maxReconnectAttempts int
    55  
    56  	// Support both ws and wss protocols
    57  	protocol string
    58  
    59  	wg sync.WaitGroup
    60  
    61  	mtx            tmsync.RWMutex
    62  	sentLastPingAt time.Time
    63  	reconnecting   bool
    64  	nextReqID      int
    65  	// sentIDs        map[types.JSONRPCIntID]bool // IDs of the requests currently in flight
    66  
    67  	// Time allowed to write a message to the server. 0 means block until operation succeeds.
    68  	writeWait time.Duration
    69  
    70  	// Time allowed to read the next message from the server. 0 means block until operation succeeds.
    71  	readWait time.Duration
    72  
    73  	// Send pings to server with this period. Must be less than readWait. If 0, no pings will be sent.
    74  	pingPeriod time.Duration
    75  
    76  	service.BaseService
    77  
    78  	// Time between sending a ping and receiving a pong. See
    79  	// https://godoc.org/github.com/rcrowley/go-metrics#Timer.
    80  	PingPongLatencyTimer metrics.Timer
    81  }
    82  
    83  // NewWS returns a new client. See the commentary on the func(*WSClient)
    84  // functions for a detailed description of how to configure ping period and
    85  // pong wait time. The endpoint argument must begin with a `/`.
    86  // An error is returned on invalid remote. The function panics when remote is nil.
    87  func NewWS(remoteAddr, endpoint string, options ...func(*WSClient)) (*WSClient, error) {
    88  	parsedURL, err := newParsedURL(remoteAddr)
    89  	if err != nil {
    90  		return nil, err
    91  	}
    92  	// default to ws protocol, unless wss or https is specified
    93  	if parsedURL.Scheme == protoHTTPS {
    94  		parsedURL.Scheme = protoWSS
    95  	} else if parsedURL.Scheme != protoWSS {
    96  		parsedURL.Scheme = protoWS
    97  	}
    98  
    99  	dialFn, err := makeHTTPDialer(remoteAddr)
   100  	if err != nil {
   101  		return nil, err
   102  	}
   103  
   104  	c := &WSClient{
   105  		Address:              parsedURL.GetTrimmedHostWithPath(),
   106  		Dialer:               dialFn,
   107  		Endpoint:             endpoint,
   108  		PingPongLatencyTimer: metrics.NewTimer(),
   109  
   110  		maxReconnectAttempts: defaultMaxReconnectAttempts,
   111  		readWait:             defaultReadWait,
   112  		writeWait:            defaultWriteWait,
   113  		pingPeriod:           defaultPingPeriod,
   114  		protocol:             parsedURL.Scheme,
   115  
   116  		// sentIDs: make(map[types.JSONRPCIntID]bool),
   117  	}
   118  	c.BaseService = *service.NewBaseService(nil, "WSClient", c)
   119  	for _, option := range options {
   120  		option(c)
   121  	}
   122  	return c, nil
   123  }
   124  
   125  // MaxReconnectAttempts sets the maximum number of reconnect attempts before returning an error.
   126  // It should only be used in the constructor and is not Goroutine-safe.
   127  func MaxReconnectAttempts(max int) func(*WSClient) {
   128  	return func(c *WSClient) {
   129  		c.maxReconnectAttempts = max
   130  	}
   131  }
   132  
   133  // ReadWait sets the amount of time to wait before a websocket read times out.
   134  // It should only be used in the constructor and is not Goroutine-safe.
   135  func ReadWait(readWait time.Duration) func(*WSClient) {
   136  	return func(c *WSClient) {
   137  		c.readWait = readWait
   138  	}
   139  }
   140  
   141  // WriteWait sets the amount of time to wait before a websocket write times out.
   142  // It should only be used in the constructor and is not Goroutine-safe.
   143  func WriteWait(writeWait time.Duration) func(*WSClient) {
   144  	return func(c *WSClient) {
   145  		c.writeWait = writeWait
   146  	}
   147  }
   148  
   149  // PingPeriod sets the duration for sending websocket pings.
   150  // It should only be used in the constructor - not Goroutine-safe.
   151  func PingPeriod(pingPeriod time.Duration) func(*WSClient) {
   152  	return func(c *WSClient) {
   153  		c.pingPeriod = pingPeriod
   154  	}
   155  }
   156  
   157  // OnReconnect sets the callback, which will be called every time after
   158  // successful reconnect.
   159  func OnReconnect(cb func()) func(*WSClient) {
   160  	return func(c *WSClient) {
   161  		c.onReconnect = cb
   162  	}
   163  }
   164  
   165  // String returns WS client full address.
   166  func (c *WSClient) String() string {
   167  	return fmt.Sprintf("WSClient{%s (%s)}", c.Address, c.Endpoint)
   168  }
   169  
   170  // OnStart implements service.Service by dialing a server and creating read and
   171  // write routines.
   172  func (c *WSClient) OnStart() error {
   173  	err := c.dial()
   174  	if err != nil {
   175  		return err
   176  	}
   177  
   178  	c.ResponsesCh = make(chan types.RPCResponse)
   179  
   180  	c.send = make(chan types.RPCRequest)
   181  	// 1 additional error may come from the read/write
   182  	// goroutine depending on which failed first.
   183  	c.reconnectAfter = make(chan error, 1)
   184  	// capacity for 1 request. a user won't be able to send more because the send
   185  	// channel is unbuffered.
   186  	c.backlog = make(chan types.RPCRequest, 1)
   187  
   188  	c.startReadWriteRoutines()
   189  	go c.reconnectRoutine()
   190  
   191  	return nil
   192  }
   193  
   194  // Stop overrides service.Service#Stop. There is no other way to wait until Quit
   195  // channel is closed.
   196  func (c *WSClient) Stop() error {
   197  	if err := c.BaseService.Stop(); err != nil {
   198  		return err
   199  	}
   200  	// only close user-facing channels when we can't write to them
   201  	c.wg.Wait()
   202  	close(c.ResponsesCh)
   203  
   204  	return nil
   205  }
   206  
   207  // IsReconnecting returns true if the client is reconnecting right now.
   208  func (c *WSClient) IsReconnecting() bool {
   209  	c.mtx.RLock()
   210  	defer c.mtx.RUnlock()
   211  	return c.reconnecting
   212  }
   213  
   214  // IsActive returns true if the client is running and not reconnecting.
   215  func (c *WSClient) IsActive() bool {
   216  	return c.IsRunning() && !c.IsReconnecting()
   217  }
   218  
   219  // Send the given RPC request to the server. Results will be available on
   220  // ResponsesCh, errors, if any, on ErrorsCh. Will block until send succeeds or
   221  // ctx.Done is closed.
   222  func (c *WSClient) Send(ctx context.Context, request types.RPCRequest) error {
   223  	select {
   224  	case c.send <- request:
   225  		c.Logger.Info("sent a request", "req", request)
   226  		// c.mtx.Lock()
   227  		// c.sentIDs[request.ID.(types.JSONRPCIntID)] = true
   228  		// c.mtx.Unlock()
   229  		return nil
   230  	case <-ctx.Done():
   231  		return ctx.Err()
   232  	}
   233  }
   234  
   235  // Call enqueues a call request onto the Send queue. Requests are JSON encoded.
   236  func (c *WSClient) Call(ctx context.Context, method string, params map[string]interface{}) error {
   237  	request, err := types.MapToRequest(c.nextRequestID(), method, params)
   238  	if err != nil {
   239  		return err
   240  	}
   241  	return c.Send(ctx, request)
   242  }
   243  
   244  // CallWithArrayParams enqueues a call request onto the Send queue. Params are
   245  // in a form of array (e.g. []interface{}{"abcd"}). Requests are JSON encoded.
   246  func (c *WSClient) CallWithArrayParams(ctx context.Context, method string, params []interface{}) error {
   247  	request, err := types.ArrayToRequest(c.nextRequestID(), method, params)
   248  	if err != nil {
   249  		return err
   250  	}
   251  	return c.Send(ctx, request)
   252  }
   253  
   254  // Private methods
   255  
   256  func (c *WSClient) nextRequestID() types.JSONRPCIntID {
   257  	c.mtx.Lock()
   258  	id := c.nextReqID
   259  	c.nextReqID++
   260  	c.mtx.Unlock()
   261  	return types.JSONRPCIntID(id)
   262  }
   263  
   264  func (c *WSClient) dial() error {
   265  	dialer := &websocket.Dialer{
   266  		NetDial: c.Dialer,
   267  		Proxy:   http.ProxyFromEnvironment,
   268  	}
   269  	rHeader := http.Header{}
   270  	conn, _, err := dialer.Dial(c.protocol+"://"+c.Address+c.Endpoint, rHeader) //nolint:bodyclose
   271  	if err != nil {
   272  		return err
   273  	}
   274  	c.conn = conn
   275  	return nil
   276  }
   277  
   278  // reconnect tries to redial up to maxReconnectAttempts with exponential
   279  // backoff.
   280  func (c *WSClient) reconnect() error {
   281  	attempt := 0
   282  
   283  	c.mtx.Lock()
   284  	c.reconnecting = true
   285  	c.mtx.Unlock()
   286  	defer func() {
   287  		c.mtx.Lock()
   288  		c.reconnecting = false
   289  		c.mtx.Unlock()
   290  	}()
   291  
   292  	for {
   293  		jitter := time.Duration(tmrand.Float64() * float64(time.Second)) // 1s == (1e9 ns)
   294  		backoffDuration := jitter + ((1 << uint(attempt)) * time.Second)
   295  
   296  		c.Logger.Info("reconnecting", "attempt", attempt+1, "backoff_duration", backoffDuration)
   297  		time.Sleep(backoffDuration)
   298  
   299  		err := c.dial()
   300  		if err != nil {
   301  			c.Logger.Error("failed to redial", "err", err)
   302  		} else {
   303  			c.Logger.Info("reconnected")
   304  			if c.onReconnect != nil {
   305  				go c.onReconnect()
   306  			}
   307  			return nil
   308  		}
   309  
   310  		attempt++
   311  
   312  		if attempt > c.maxReconnectAttempts {
   313  			return fmt.Errorf("reached maximum reconnect attempts: %w", err)
   314  		}
   315  	}
   316  }
   317  
   318  func (c *WSClient) startReadWriteRoutines() {
   319  	c.wg.Add(2)
   320  	c.readRoutineQuit = make(chan struct{})
   321  	go c.readRoutine()
   322  	go c.writeRoutine()
   323  }
   324  
   325  func (c *WSClient) processBacklog() error {
   326  	select {
   327  	case request := <-c.backlog:
   328  		if c.writeWait > 0 {
   329  			if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeWait)); err != nil {
   330  				c.Logger.Error("failed to set write deadline", "err", err)
   331  			}
   332  		}
   333  		if err := c.conn.WriteJSON(request); err != nil {
   334  			c.Logger.Error("failed to resend request", "err", err)
   335  			c.reconnectAfter <- err
   336  			// requeue request
   337  			c.backlog <- request
   338  			return err
   339  		}
   340  		c.Logger.Info("resend a request", "req", request)
   341  	default:
   342  	}
   343  	return nil
   344  }
   345  
   346  func (c *WSClient) reconnectRoutine() {
   347  	for {
   348  		select {
   349  		case originalError := <-c.reconnectAfter:
   350  			// wait until writeRoutine and readRoutine finish
   351  			c.wg.Wait()
   352  			if err := c.reconnect(); err != nil {
   353  				c.Logger.Error("failed to reconnect", "err", err, "original_err", originalError)
   354  				if err = c.Stop(); err != nil {
   355  					c.Logger.Error("failed to stop conn", "error", err)
   356  				}
   357  
   358  				return
   359  			}
   360  			// drain reconnectAfter
   361  		LOOP:
   362  			for {
   363  				select {
   364  				case <-c.reconnectAfter:
   365  				default:
   366  					break LOOP
   367  				}
   368  			}
   369  			err := c.processBacklog()
   370  			if err == nil {
   371  				c.startReadWriteRoutines()
   372  			}
   373  
   374  		case <-c.Quit():
   375  			return
   376  		}
   377  	}
   378  }
   379  
   380  // The client ensures that there is at most one writer to a connection by
   381  // executing all writes from this goroutine.
   382  func (c *WSClient) writeRoutine() {
   383  	var ticker *time.Ticker
   384  	if c.pingPeriod > 0 {
   385  		// ticker with a predefined period
   386  		ticker = time.NewTicker(c.pingPeriod)
   387  	} else {
   388  		// ticker that never fires
   389  		ticker = &time.Ticker{C: make(<-chan time.Time)}
   390  	}
   391  
   392  	defer func() {
   393  		ticker.Stop()
   394  		c.conn.Close()
   395  		// err != nil {
   396  		// ignore error; it will trigger in tests
   397  		// likely because it's closing an already closed connection
   398  		// }
   399  		c.wg.Done()
   400  	}()
   401  
   402  	for {
   403  		select {
   404  		case request := <-c.send:
   405  			if c.writeWait > 0 {
   406  				if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeWait)); err != nil {
   407  					c.Logger.Error("failed to set write deadline", "err", err)
   408  				}
   409  			}
   410  			if err := c.conn.WriteJSON(request); err != nil {
   411  				c.Logger.Error("failed to send request", "err", err)
   412  				c.reconnectAfter <- err
   413  				// add request to the backlog, so we don't lose it
   414  				c.backlog <- request
   415  				return
   416  			}
   417  		case <-ticker.C:
   418  			if c.writeWait > 0 {
   419  				if err := c.conn.SetWriteDeadline(time.Now().Add(c.writeWait)); err != nil {
   420  					c.Logger.Error("failed to set write deadline", "err", err)
   421  				}
   422  			}
   423  			if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
   424  				c.Logger.Error("failed to write ping", "err", err)
   425  				c.reconnectAfter <- err
   426  				return
   427  			}
   428  			c.mtx.Lock()
   429  			c.sentLastPingAt = time.Now()
   430  			c.mtx.Unlock()
   431  			c.Logger.Debug("sent ping")
   432  		case <-c.readRoutineQuit:
   433  			return
   434  		case <-c.Quit():
   435  			if err := c.conn.WriteMessage(
   436  				websocket.CloseMessage,
   437  				websocket.FormatCloseMessage(websocket.CloseNormalClosure, ""),
   438  			); err != nil {
   439  				c.Logger.Error("failed to write message", "err", err)
   440  			}
   441  			return
   442  		}
   443  	}
   444  }
   445  
   446  // The client ensures that there is at most one reader to a connection by
   447  // executing all reads from this goroutine.
   448  func (c *WSClient) readRoutine() {
   449  	defer func() {
   450  		c.conn.Close()
   451  		// err != nil {
   452  		// ignore error; it will trigger in tests
   453  		// likely because it's closing an already closed connection
   454  		// }
   455  		c.wg.Done()
   456  	}()
   457  
   458  	c.conn.SetPongHandler(func(string) error {
   459  		// gather latency stats
   460  		c.mtx.RLock()
   461  		t := c.sentLastPingAt
   462  		c.mtx.RUnlock()
   463  		c.PingPongLatencyTimer.UpdateSince(t)
   464  
   465  		c.Logger.Debug("got pong")
   466  		return nil
   467  	})
   468  
   469  	for {
   470  		// reset deadline for every message type (control or data)
   471  		if c.readWait > 0 {
   472  			if err := c.conn.SetReadDeadline(time.Now().Add(c.readWait)); err != nil {
   473  				c.Logger.Error("failed to set read deadline", "err", err)
   474  			}
   475  		}
   476  		_, data, err := c.conn.ReadMessage()
   477  		if err != nil {
   478  			if !websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
   479  				return
   480  			}
   481  
   482  			c.Logger.Error("failed to read response", "err", err)
   483  			close(c.readRoutineQuit)
   484  			c.reconnectAfter <- err
   485  			return
   486  		}
   487  
   488  		var response types.RPCResponse
   489  		err = json.Unmarshal(data, &response)
   490  		if err != nil {
   491  			c.Logger.Error("failed to parse response", "err", err, "data", string(data))
   492  			continue
   493  		}
   494  
   495  		if err = validateResponseID(response.ID); err != nil {
   496  			c.Logger.Error("error in response ID", "id", response.ID, "err", err)
   497  			continue
   498  		}
   499  
   500  		// TODO: events resulting from /subscribe do not work with ->
   501  		// because they are implemented as responses with the subscribe request's
   502  		// ID. According to the spec, they should be notifications (requests
   503  		// without IDs).
   504  		// https://github.com/vipernet-xyz/tm/issues/2949
   505  		// c.mtx.Lock()
   506  		// if _, ok := c.sentIDs[response.ID.(types.JSONRPCIntID)]; !ok {
   507  		// 	c.Logger.Error("unsolicited response ID", "id", response.ID, "expected", c.sentIDs)
   508  		// 	c.mtx.Unlock()
   509  		// 	continue
   510  		// }
   511  		// delete(c.sentIDs, response.ID.(types.JSONRPCIntID))
   512  		// c.mtx.Unlock()
   513  		// Combine a non-blocking read on BaseService.Quit with a non-blocking write on ResponsesCh to avoid blocking
   514  		// c.wg.Wait() in c.Stop(). Note we rely on Quit being closed so that it sends unlimited Quit signals to stop
   515  		// both readRoutine and writeRoutine
   516  
   517  		c.Logger.Info("got response", "id", response.ID, "result", log.NewLazySprintf("%X", response.Result))
   518  
   519  		select {
   520  		case <-c.Quit():
   521  		case c.ResponsesCh <- response:
   522  		}
   523  	}
   524  }
   525  
   526  // Predefined methods
   527  
   528  // Subscribe to a query. Note the server must have a "subscribe" route
   529  // defined.
   530  func (c *WSClient) Subscribe(ctx context.Context, query string) error {
   531  	params := map[string]interface{}{"query": query}
   532  	return c.Call(ctx, "subscribe", params)
   533  }
   534  
   535  // Unsubscribe from a query. Note the server must have a "unsubscribe" route
   536  // defined.
   537  func (c *WSClient) Unsubscribe(ctx context.Context, query string) error {
   538  	params := map[string]interface{}{"query": query}
   539  	return c.Call(ctx, "unsubscribe", params)
   540  }
   541  
   542  // UnsubscribeAll from all. Note the server must have a "unsubscribe_all" route
   543  // defined.
   544  func (c *WSClient) UnsubscribeAll(ctx context.Context) error {
   545  	params := map[string]interface{}{}
   546  	return c.Call(ctx, "unsubscribe_all", params)
   547  }