github.com/gnolang/gno@v0.0.0-20240520182011-228e9d0192ce/tm2/pkg/bft/rpc/lib/client/ws/client.go (about)

     1  package ws
     2  
     3  import (
     4  	"context"
     5  	"encoding/json"
     6  	"fmt"
     7  	"hash/fnv"
     8  	"log/slog"
     9  	"sync"
    10  
    11  	types "github.com/gnolang/gno/tm2/pkg/bft/rpc/lib/types"
    12  	"github.com/gnolang/gno/tm2/pkg/errors"
    13  	"github.com/gnolang/gno/tm2/pkg/log"
    14  	"github.com/gorilla/websocket"
    15  )
    16  
    17  var (
    18  	ErrTimedOut                  = errors.New("context timed out")
    19  	ErrRequestResponseIDMismatch = errors.New("ws request / response ID mismatch")
    20  	ErrInvalidBatchResponse      = errors.New("invalid ws batch response size")
    21  )
    22  
    23  type responseCh chan<- types.RPCResponses
    24  
    25  // Client is a WebSocket client implementation
    26  type Client struct {
    27  	ctx           context.Context
    28  	cancelCauseFn context.CancelCauseFunc
    29  
    30  	conn *websocket.Conn
    31  
    32  	logger  *slog.Logger
    33  	backlog chan any // Either a single RPC request, or a batch of RPC requests
    34  
    35  	requestMap    map[string]responseCh
    36  	requestMapMux sync.Mutex
    37  }
    38  
    39  // NewClient initializes and creates a new WS RPC client
    40  func NewClient(rpcURL string, opts ...Option) (*Client, error) {
    41  	// Dial the RPC URL
    42  	conn, _, err := websocket.DefaultDialer.Dial(rpcURL, nil)
    43  	if err != nil {
    44  		return nil, fmt.Errorf("unable to dial RPC, %w", err)
    45  	}
    46  
    47  	c := &Client{
    48  		conn:       conn,
    49  		requestMap: make(map[string]responseCh),
    50  		backlog:    make(chan any, 1),
    51  		logger:     log.NewNoopLogger(),
    52  	}
    53  
    54  	ctx, cancelFn := context.WithCancelCause(context.Background())
    55  	c.ctx = ctx
    56  	c.cancelCauseFn = cancelFn
    57  
    58  	// Apply the options
    59  	for _, opt := range opts {
    60  		opt(c)
    61  	}
    62  
    63  	go c.runReadRoutine(ctx)
    64  	go c.runWriteRoutine(ctx)
    65  
    66  	return c, nil
    67  }
    68  
    69  // SendRequest sends a single RPC request to the server
    70  func (c *Client) SendRequest(ctx context.Context, request types.RPCRequest) (*types.RPCResponse, error) {
    71  	// Create the response channel for the pipeline
    72  	responseCh := make(chan types.RPCResponses, 1)
    73  
    74  	// Generate a unique request ID hash
    75  	requestHash := generateIDHash(request.ID.String())
    76  
    77  	c.requestMapMux.Lock()
    78  	c.requestMap[requestHash] = responseCh
    79  	c.requestMapMux.Unlock()
    80  
    81  	// Pipe the request to the backlog
    82  	select {
    83  	case <-ctx.Done():
    84  		return nil, ErrTimedOut
    85  	case <-c.ctx.Done():
    86  		return nil, context.Cause(c.ctx)
    87  	case c.backlog <- request:
    88  	}
    89  
    90  	// Wait for the response
    91  	select {
    92  	case <-ctx.Done():
    93  		return nil, ErrTimedOut
    94  	case <-c.ctx.Done():
    95  		return nil, context.Cause(c.ctx)
    96  	case response := <-responseCh:
    97  		// Make sure the ID matches
    98  		if response[0].ID != request.ID {
    99  			return nil, ErrRequestResponseIDMismatch
   100  		}
   101  
   102  		return &response[0], nil
   103  	}
   104  }
   105  
   106  // SendBatch sends a batch of RPC requests to the server
   107  func (c *Client) SendBatch(ctx context.Context, requests types.RPCRequests) (types.RPCResponses, error) {
   108  	// Create the response channel for the pipeline
   109  	responseCh := make(chan types.RPCResponses, 1)
   110  
   111  	// Generate a unique request ID hash
   112  	requestIDs := make([]string, 0, len(requests))
   113  
   114  	for _, request := range requests {
   115  		requestIDs = append(requestIDs, request.ID.String())
   116  	}
   117  
   118  	requestHash := generateIDHash(requestIDs...)
   119  
   120  	c.requestMapMux.Lock()
   121  	c.requestMap[requestHash] = responseCh
   122  	c.requestMapMux.Unlock()
   123  
   124  	// Pipe the request to the backlog
   125  	select {
   126  	case <-ctx.Done():
   127  		return nil, ErrTimedOut
   128  	case <-c.ctx.Done():
   129  		return nil, context.Cause(c.ctx)
   130  	case c.backlog <- requests:
   131  	}
   132  
   133  	// Wait for the response
   134  	select {
   135  	case <-ctx.Done():
   136  		return nil, ErrTimedOut
   137  	case <-c.ctx.Done():
   138  		return nil, context.Cause(c.ctx)
   139  	case responses := <-responseCh:
   140  		// Make sure the length matches
   141  		if len(responses) != len(requests) {
   142  			return nil, ErrInvalidBatchResponse
   143  		}
   144  
   145  		// Make sure the IDs match
   146  		for index, response := range responses {
   147  			if requests[index].ID != response.ID {
   148  				return nil, ErrRequestResponseIDMismatch
   149  			}
   150  		}
   151  
   152  		return responses, nil
   153  	}
   154  }
   155  
   156  // generateIDHash generates a unique hash from the given IDs
   157  func generateIDHash(ids ...string) string {
   158  	hash := fnv.New128()
   159  
   160  	for _, id := range ids {
   161  		hash.Write([]byte(id))
   162  	}
   163  
   164  	return string(hash.Sum(nil))
   165  }
   166  
   167  // runWriteRoutine runs the client -> server write routine
   168  func (c *Client) runWriteRoutine(ctx context.Context) {
   169  	for {
   170  		select {
   171  		case <-ctx.Done():
   172  			c.logger.Debug("write context finished")
   173  
   174  			return
   175  		case item := <-c.backlog:
   176  			// Write the JSON request to the server
   177  			if err := c.conn.WriteJSON(item); err != nil {
   178  				c.logger.Error("unable to send request", "err", err)
   179  
   180  				continue
   181  			}
   182  
   183  			c.logger.Debug("successfully sent request", "request", item)
   184  		}
   185  	}
   186  }
   187  
   188  // runReadRoutine runs the client <- server read routine
   189  func (c *Client) runReadRoutine(ctx context.Context) {
   190  	for {
   191  		select {
   192  		case <-ctx.Done():
   193  			c.logger.Debug("read context finished")
   194  
   195  			return
   196  		default:
   197  		}
   198  
   199  		// Read the message from the active connection
   200  		_, data, err := c.conn.ReadMessage()
   201  		if err != nil {
   202  			if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) {
   203  				c.logger.Error("failed to read response", "err", err)
   204  
   205  				// Server dropped the connection, stop the client
   206  				if err = c.closeWithCause(
   207  					fmt.Errorf("server closed connection, %w", err),
   208  				); err != nil {
   209  					c.logger.Error("unable to gracefully close client", "err", err)
   210  				}
   211  
   212  				return
   213  			}
   214  
   215  			continue
   216  		}
   217  
   218  		var (
   219  			responses    types.RPCResponses
   220  			responseHash string
   221  		)
   222  
   223  		// Try to unmarshal as a batch of responses first
   224  		if err := json.Unmarshal(data, &responses); err != nil {
   225  			// Try to unmarshal as a single response
   226  			var response types.RPCResponse
   227  
   228  			if err := json.Unmarshal(data, &response); err != nil {
   229  				c.logger.Error("failed to parse response", "err", err, "data", string(data))
   230  
   231  				continue
   232  			}
   233  
   234  			// This is a single response, generate the unique ID
   235  			responseHash = generateIDHash(response.ID.String())
   236  			responses = types.RPCResponses{response}
   237  		} else {
   238  			// This is a batch response, generate the unique ID
   239  			// from the combined IDs
   240  			ids := make([]string, 0, len(responses))
   241  
   242  			for _, response := range responses {
   243  				ids = append(ids, response.ID.String())
   244  			}
   245  
   246  			responseHash = generateIDHash(ids...)
   247  		}
   248  
   249  		// Grab the response channel
   250  		c.requestMapMux.Lock()
   251  		ch := c.requestMap[responseHash]
   252  		if ch == nil {
   253  			c.requestMapMux.Unlock()
   254  			c.logger.Error("response listener not set", "hash", responseHash, "responses", responses)
   255  
   256  			continue
   257  		}
   258  
   259  		// Clear the entry for this ID
   260  		delete(c.requestMap, responseHash)
   261  		c.requestMapMux.Unlock()
   262  
   263  		c.logger.Debug("received response", "hash", responseHash)
   264  
   265  		// Alert the listener of the response
   266  		select {
   267  		case ch <- responses:
   268  		default:
   269  			c.logger.Warn("response listener timed out", "hash", responseHash)
   270  		}
   271  	}
   272  }
   273  
   274  // Close closes the WS client
   275  func (c *Client) Close() error {
   276  	return c.closeWithCause(nil)
   277  }
   278  
   279  // closeWithCause closes the client (and any open connection)
   280  // with the given cause
   281  func (c *Client) closeWithCause(err error) error {
   282  	c.cancelCauseFn(err)
   283  
   284  	return c.conn.Close()
   285  }