github.com/gagliardetto/solana-go@v1.11.0/rpc/ws/client.go (about)

     1  // Copyright 2021 github.com/gagliardetto
     2  // This file has been modified by github.com/gagliardetto
     3  //
     4  // Copyright 2020 dfuse Platform Inc.
     5  //
     6  // Licensed under the Apache License, Version 2.0 (the "License");
     7  // you may not use this file except in compliance with the License.
     8  // You may obtain a copy of the License at
     9  //
    10  //      http://www.apache.org/licenses/LICENSE-2.0
    11  //
    12  // Unless required by applicable law or agreed to in writing, software
    13  // distributed under the License is distributed on an "AS IS" BASIS,
    14  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    15  // See the License for the specific language governing permissions and
    16  // limitations under the License.
    17  
    18  package ws
    19  
    20  import (
    21  	"context"
    22  	"fmt"
    23  	"io"
    24  	"net/http"
    25  	"strconv"
    26  	"sync"
    27  	"time"
    28  
    29  	"github.com/buger/jsonparser"
    30  	"github.com/gorilla/rpc/v2/json2"
    31  	"github.com/gorilla/websocket"
    32  	"go.uber.org/zap"
    33  )
    34  
// result is an untyped payload value.
// NOTE(review): presumably this is the value type returned by a subscription's
// decoderFunc and carried on Subscription.stream (both declared elsewhere) —
// confirm against those declarations.
type result interface{}
    36  
// Client is a websocket JSON-RPC client that carries all of its subscriptions
// over a single connection.
type Client struct {
	// rpcURL is the endpoint passed to Connect/ConnectWithOptions.
	rpcURL                  string
	// conn is the dialed websocket; in this file every write to it happens
	// while lock is held.
	conn                    *websocket.Conn
	// connCtx is cancelled by Close; the ping goroutine and the read loop
	// both stop when it is done.
	connCtx                 context.Context
	connCtxCancel           context.CancelFunc
	// lock guards the two subscription maps below and serializes writes to conn.
	lock                    sync.RWMutex
	// subscriptionByRequestID indexes subscriptions by the id of the JSON-RPC
	// request that created them.
	subscriptionByRequestID map[uint64]*Subscription
	// subscriptionByWSSubID indexes subscriptions by the server-assigned
	// subscription id taken from the subscribe reply's "result".
	subscriptionByWSSubID   map[uint64]*Subscription
	// reconnectOnErr is not read anywhere in this file.
	// NOTE(review): confirm whether another file in the package uses it.
	reconnectOnErr          bool
}
    47  
    48  const (
    49  	// Time allowed to write a message to the peer.
    50  	writeWait = 10 * time.Second
    51  	// Time allowed to read the next pong message from the peer.
    52  	pongWait = 60 * time.Second
    53  	// Send pings to peer with this period. Must be less than pongWait.
    54  	pingPeriod = (pongWait * 9) / 10
    55  )
    56  
    57  // Connect creates a new websocket client connecting to the provided endpoint.
    58  func Connect(ctx context.Context, rpcEndpoint string) (c *Client, err error) {
    59  	return ConnectWithOptions(ctx, rpcEndpoint, nil)
    60  }
    61  
    62  // ConnectWithOptions creates a new websocket client connecting to the provided
    63  // endpoint with a http header if available The http header can be helpful to
    64  // pass basic authentication params as prescribed
    65  // ref https://github.com/gorilla/websocket/issues/209
    66  func ConnectWithOptions(ctx context.Context, rpcEndpoint string, opt *Options) (c *Client, err error) {
    67  	c = &Client{
    68  		rpcURL:                  rpcEndpoint,
    69  		subscriptionByRequestID: map[uint64]*Subscription{},
    70  		subscriptionByWSSubID:   map[uint64]*Subscription{},
    71  	}
    72  
    73  	dialer := &websocket.Dialer{
    74  		Proxy:             http.ProxyFromEnvironment,
    75  		HandshakeTimeout:  DefaultHandshakeTimeout,
    76  		EnableCompression: true,
    77  	}
    78  
    79  	if opt != nil && opt.HandshakeTimeout > 0 {
    80  		dialer.HandshakeTimeout = opt.HandshakeTimeout
    81  	}
    82  
    83  	var httpHeader http.Header = nil
    84  	if opt != nil && opt.HttpHeader != nil && len(opt.HttpHeader) > 0 {
    85  		httpHeader = opt.HttpHeader
    86  	}
    87  	var resp *http.Response
    88  	c.conn, resp, err = dialer.DialContext(ctx, rpcEndpoint, httpHeader)
    89  	if err != nil {
    90  		if resp != nil {
    91  			body, _ := io.ReadAll(resp.Body)
    92  			err = fmt.Errorf("new ws client: dial: %w, status: %s, body: %q", err, resp.Status, string(body))
    93  		} else {
    94  			err = fmt.Errorf("new ws client: dial: %w", err)
    95  		}
    96  		return nil, err
    97  	}
    98  
    99  	c.connCtx, c.connCtxCancel = context.WithCancel(context.Background())
   100  	go func() {
   101  		c.conn.SetReadDeadline(time.Now().Add(pongWait))
   102  		c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil })
   103  		ticker := time.NewTicker(pingPeriod)
   104  		for {
   105  			select {
   106  			case <-c.connCtx.Done():
   107  				return
   108  			case <-ticker.C:
   109  				c.sendPing()
   110  			}
   111  		}
   112  	}()
   113  	go c.receiveMessages()
   114  	return c, nil
   115  }
   116  
   117  func (c *Client) sendPing() {
   118  	c.lock.Lock()
   119  	defer c.lock.Unlock()
   120  
   121  	c.conn.SetWriteDeadline(time.Now().Add(writeWait))
   122  	if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil {
   123  		return
   124  	}
   125  }
   126  
   127  func (c *Client) Close() {
   128  	c.lock.Lock()
   129  	defer c.lock.Unlock()
   130  	c.connCtxCancel()
   131  	c.conn.Close()
   132  }
   133  
   134  func (c *Client) receiveMessages() {
   135  	for {
   136  		select {
   137  		case <-c.connCtx.Done():
   138  			return
   139  		default:
   140  			_, message, err := c.conn.ReadMessage()
   141  			if err != nil {
   142  				c.closeAllSubscription(err)
   143  				return
   144  			}
   145  			c.handleMessage(message)
   146  		}
   147  	}
   148  }
   149  
   150  // GetUint64 returns the value retrieved by `Get`, cast to a uint64 if possible.
   151  // If key data type do not match, it will return an error.
   152  func getUint64(data []byte, keys ...string) (val uint64, err error) {
   153  	v, t, _, e := jsonparser.Get(data, keys...)
   154  	if e != nil {
   155  		return 0, e
   156  	}
   157  	if t != jsonparser.Number {
   158  		return 0, fmt.Errorf("Value is not a number: %s", string(v))
   159  	}
   160  	return strconv.ParseUint(string(v), 10, 64)
   161  }
   162  
   163  func getUint64WithOk(data []byte, path ...string) (uint64, bool) {
   164  	val, err := getUint64(data, path...)
   165  	if err == nil {
   166  		return val, true
   167  	}
   168  	return 0, false
   169  }
   170  
   171  func (c *Client) handleMessage(message []byte) {
   172  	// when receiving message with id. the result will be a subscription number.
   173  	// that number will be associated to all future message destine to this request
   174  
   175  	requestID, ok := getUint64WithOk(message, "id")
   176  	if ok {
   177  		subID, _ := getUint64WithOk(message, "result")
   178  		c.handleNewSubscriptionMessage(requestID, subID)
   179  		return
   180  	}
   181  
   182  	subID, _ := getUint64WithOk(message, "params", "subscription")
   183  	c.handleSubscriptionMessage(subID, message)
   184  }
   185  
   186  func (c *Client) handleNewSubscriptionMessage(requestID, subID uint64) {
   187  	c.lock.Lock()
   188  	defer c.lock.Unlock()
   189  
   190  	if traceEnabled {
   191  		zlog.Debug("received new subscription message",
   192  			zap.Uint64("message_id", requestID),
   193  			zap.Uint64("subscription_id", subID),
   194  		)
   195  	}
   196  
   197  	callBack, found := c.subscriptionByRequestID[requestID]
   198  	if !found {
   199  		zlog.Error("cannot find websocket message handler for a new stream.... this should not happen",
   200  			zap.Uint64("request_id", requestID),
   201  			zap.Uint64("subscription_id", subID),
   202  		)
   203  		return
   204  	}
   205  	callBack.subID = subID
   206  	c.subscriptionByWSSubID[subID] = callBack
   207  
   208  	zlog.Debug("registered ws subscription",
   209  		zap.Uint64("subscription_id", subID),
   210  		zap.Uint64("request_id", requestID),
   211  		zap.Int("subscription_count", len(c.subscriptionByWSSubID)),
   212  	)
   213  	return
   214  }
   215  
   216  func (c *Client) handleSubscriptionMessage(subID uint64, message []byte) {
   217  	if traceEnabled {
   218  		zlog.Debug("received subscription message",
   219  			zap.Uint64("subscription_id", subID),
   220  		)
   221  	}
   222  
   223  	c.lock.RLock()
   224  	sub, found := c.subscriptionByWSSubID[subID]
   225  	c.lock.RUnlock()
   226  	if !found {
   227  		zlog.Warn("unable to find subscription for ws message", zap.Uint64("subscription_id", subID))
   228  		return
   229  	}
   230  
   231  	// Decode the message using the subscription-provided decoderFunc.
   232  	result, err := sub.decoderFunc(message)
   233  	if err != nil {
   234  		fmt.Println("*****************************")
   235  		c.closeSubscription(sub.req.ID, fmt.Errorf("unable to decode client response: %w", err))
   236  		return
   237  	}
   238  
   239  	// this cannot be blocking or else
   240  	// we  will no read any other message
   241  	if len(sub.stream) >= cap(sub.stream) {
   242  		zlog.Warn("closing ws client subscription... not consuming fast en ought",
   243  			zap.Uint64("request_id", sub.req.ID),
   244  		)
   245  		c.closeSubscription(sub.req.ID, fmt.Errorf("reached channel max capacity %d", len(sub.stream)))
   246  		return
   247  	}
   248  
   249  	if !sub.closed {
   250  		sub.stream <- result
   251  	}
   252  	return
   253  }
   254  
   255  func (c *Client) closeAllSubscription(err error) {
   256  	c.lock.Lock()
   257  	defer c.lock.Unlock()
   258  
   259  	for _, sub := range c.subscriptionByRequestID {
   260  		sub.err <- err
   261  	}
   262  
   263  	c.subscriptionByRequestID = map[uint64]*Subscription{}
   264  	c.subscriptionByWSSubID = map[uint64]*Subscription{}
   265  }
   266  
   267  func (c *Client) closeSubscription(reqID uint64, err error) {
   268  	c.lock.Lock()
   269  	defer c.lock.Unlock()
   270  
   271  	sub, found := c.subscriptionByRequestID[reqID]
   272  	if !found {
   273  		return
   274  	}
   275  
   276  	sub.err <- err
   277  
   278  	err = c.unsubscribe(sub.subID, sub.unsubscribeMethod)
   279  	if err != nil {
   280  		zlog.Warn("unable to send rpc unsubscribe call",
   281  			zap.Error(err),
   282  		)
   283  	}
   284  
   285  	delete(c.subscriptionByRequestID, sub.req.ID)
   286  	delete(c.subscriptionByWSSubID, sub.subID)
   287  }
   288  
   289  func (c *Client) unsubscribe(subID uint64, method string) error {
   290  	req := newRequest([]interface{}{subID}, method, nil)
   291  	data, err := req.encode()
   292  	if err != nil {
   293  		return fmt.Errorf("unable to encode unsubscription message for subID %d and method %s", subID, method)
   294  	}
   295  
   296  	c.conn.SetWriteDeadline(time.Now().Add(writeWait))
   297  	err = c.conn.WriteMessage(websocket.TextMessage, data)
   298  	if err != nil {
   299  		return fmt.Errorf("unable to send unsubscription message for subID %d and method %s", subID, method)
   300  	}
   301  	return nil
   302  }
   303  
   304  func (c *Client) subscribe(
   305  	params []interface{},
   306  	conf map[string]interface{},
   307  	subscriptionMethod string,
   308  	unsubscribeMethod string,
   309  	decoderFunc decoderFunc,
   310  ) (*Subscription, error) {
   311  	c.lock.Lock()
   312  	defer c.lock.Unlock()
   313  
   314  	req := newRequest(params, subscriptionMethod, conf)
   315  	data, err := req.encode()
   316  	if err != nil {
   317  		return nil, fmt.Errorf("subscribe: unable to encode subsciption request: %w", err)
   318  	}
   319  
   320  	sub := newSubscription(
   321  		req,
   322  		func(err error) {
   323  			c.closeSubscription(req.ID, err)
   324  		},
   325  		unsubscribeMethod,
   326  		decoderFunc,
   327  	)
   328  
   329  	c.subscriptionByRequestID[req.ID] = sub
   330  	zlog.Info("added new subscription to websocket client", zap.Int("count", len(c.subscriptionByRequestID)))
   331  
   332  	zlog.Debug("writing data to conn", zap.String("data", string(data)))
   333  	c.conn.SetWriteDeadline(time.Now().Add(writeWait))
   334  	err = c.conn.WriteMessage(websocket.TextMessage, data)
   335  	if err != nil {
   336  		return nil, fmt.Errorf("unable to write request: %w", err)
   337  	}
   338  
   339  	return sub, nil
   340  }
   341  
   342  func decodeResponseFromReader(r io.Reader, reply interface{}) (err error) {
   343  	var c *response
   344  	if err := json.NewDecoder(r).Decode(&c); err != nil {
   345  		return err
   346  	}
   347  
   348  	if c.Error != nil {
   349  		jsonErr := &json2.Error{}
   350  		if err := json.Unmarshal(*c.Error, jsonErr); err != nil {
   351  			return &json2.Error{
   352  				Code:    json2.E_SERVER,
   353  				Message: string(*c.Error),
   354  			}
   355  		}
   356  		return jsonErr
   357  	}
   358  
   359  	if c.Params == nil {
   360  		return json2.ErrNullResult
   361  	}
   362  
   363  	return json.Unmarshal(*c.Params.Result, &reply)
   364  }
   365  
   366  func decodeResponseFromMessage(r []byte, reply interface{}) (err error) {
   367  	var c *response
   368  	if err := json.Unmarshal(r, &c); err != nil {
   369  		return err
   370  	}
   371  
   372  	if c.Error != nil {
   373  		jsonErr := &json2.Error{}
   374  		if err := json.Unmarshal(*c.Error, jsonErr); err != nil {
   375  			return &json2.Error{
   376  				Code:    json2.E_SERVER,
   377  				Message: string(*c.Error),
   378  			}
   379  		}
   380  		return jsonErr
   381  	}
   382  
   383  	if c.Params == nil {
   384  		return json2.ErrNullResult
   385  	}
   386  
   387  	return json.Unmarshal(*c.Params.Result, &reply)
   388  }