github.com/gnolang/gno@v0.0.0-20240520182011-228e9d0192ce/tm2/pkg/bft/rpc/lib/client/ws/client.go (about) 1 package ws 2 3 import ( 4 "context" 5 "encoding/json" 6 "fmt" 7 "hash/fnv" 8 "log/slog" 9 "sync" 10 11 types "github.com/gnolang/gno/tm2/pkg/bft/rpc/lib/types" 12 "github.com/gnolang/gno/tm2/pkg/errors" 13 "github.com/gnolang/gno/tm2/pkg/log" 14 "github.com/gorilla/websocket" 15 ) 16 17 var ( 18 ErrTimedOut = errors.New("context timed out") 19 ErrRequestResponseIDMismatch = errors.New("ws request / response ID mismatch") 20 ErrInvalidBatchResponse = errors.New("invalid ws batch response size") 21 ) 22 23 type responseCh chan<- types.RPCResponses 24 25 // Client is a WebSocket client implementation 26 type Client struct { 27 ctx context.Context 28 cancelCauseFn context.CancelCauseFunc 29 30 conn *websocket.Conn 31 32 logger *slog.Logger 33 backlog chan any // Either a single RPC request, or a batch of RPC requests 34 35 requestMap map[string]responseCh 36 requestMapMux sync.Mutex 37 } 38 39 // NewClient initializes and creates a new WS RPC client 40 func NewClient(rpcURL string, opts ...Option) (*Client, error) { 41 // Dial the RPC URL 42 conn, _, err := websocket.DefaultDialer.Dial(rpcURL, nil) 43 if err != nil { 44 return nil, fmt.Errorf("unable to dial RPC, %w", err) 45 } 46 47 c := &Client{ 48 conn: conn, 49 requestMap: make(map[string]responseCh), 50 backlog: make(chan any, 1), 51 logger: log.NewNoopLogger(), 52 } 53 54 ctx, cancelFn := context.WithCancelCause(context.Background()) 55 c.ctx = ctx 56 c.cancelCauseFn = cancelFn 57 58 // Apply the options 59 for _, opt := range opts { 60 opt(c) 61 } 62 63 go c.runReadRoutine(ctx) 64 go c.runWriteRoutine(ctx) 65 66 return c, nil 67 } 68 69 // SendRequest sends a single RPC request to the server 70 func (c *Client) SendRequest(ctx context.Context, request types.RPCRequest) (*types.RPCResponse, error) { 71 // Create the response channel for the pipeline 72 responseCh := make(chan types.RPCResponses, 1) 73 74 // Generate a unique request ID hash 75 requestHash := generateIDHash(request.ID.String()) 76 77 c.requestMapMux.Lock() 78 c.requestMap[requestHash] = responseCh 79 c.requestMapMux.Unlock() 80 81 // Pipe the request to the backlog 82 select { 83 case <-ctx.Done(): 84 return nil, ErrTimedOut 85 case <-c.ctx.Done(): 86 return nil, context.Cause(c.ctx) 87 case c.backlog <- request: 88 } 89 90 // Wait for the response 91 select { 92 case <-ctx.Done(): 93 return nil, ErrTimedOut 94 case <-c.ctx.Done(): 95 return nil, context.Cause(c.ctx) 96 case response := <-responseCh: 97 // Make sure the ID matches 98 if response[0].ID != request.ID { 99 return nil, ErrRequestResponseIDMismatch 100 } 101 102 return &response[0], nil 103 } 104 } 105 106 // SendBatch sends a batch of RPC requests to the server 107 func (c *Client) SendBatch(ctx context.Context, requests types.RPCRequests) (types.RPCResponses, error) { 108 // Create the response channel for the pipeline 109 responseCh := make(chan types.RPCResponses, 1) 110 111 // Generate a unique request ID hash 112 requestIDs := make([]string, 0, len(requests)) 113 114 for _, request := range requests { 115 requestIDs = append(requestIDs, request.ID.String()) 116 } 117 118 requestHash := generateIDHash(requestIDs...) 119 120 c.requestMapMux.Lock() 121 c.requestMap[requestHash] = responseCh 122 c.requestMapMux.Unlock() 123 124 // Pipe the request to the backlog 125 select { 126 case <-ctx.Done(): 127 return nil, ErrTimedOut 128 case <-c.ctx.Done(): 129 return nil, context.Cause(c.ctx) 130 case c.backlog <- requests: 131 } 132 133 // Wait for the response 134 select { 135 case <-ctx.Done(): 136 return nil, ErrTimedOut 137 case <-c.ctx.Done(): 138 return nil, context.Cause(c.ctx) 139 case responses := <-responseCh: 140 // Make sure the length matches 141 if len(responses) != len(requests) { 142 return nil, ErrInvalidBatchResponse 143 } 144 145 // Make sure the IDs match 146 for index, response := range responses { 147 if requests[index].ID != response.ID { 148 return nil, ErrRequestResponseIDMismatch 149 } 150 } 151 152 return responses, nil 153 } 154 } 155 156 // generateIDHash generates a unique hash from the given IDs 157 func generateIDHash(ids ...string) string { 158 hash := fnv.New128() 159 160 for _, id := range ids { 161 hash.Write([]byte(id)) 162 } 163 164 return string(hash.Sum(nil)) 165 } 166 167 // runWriteRoutine runs the client -> server write routine 168 func (c *Client) runWriteRoutine(ctx context.Context) { 169 for { 170 select { 171 case <-ctx.Done(): 172 c.logger.Debug("write context finished") 173 174 return 175 case item := <-c.backlog: 176 // Write the JSON request to the server 177 if err := c.conn.WriteJSON(item); err != nil { 178 c.logger.Error("unable to send request", "err", err) 179 180 continue 181 } 182 183 c.logger.Debug("successfully sent request", "request", item) 184 } 185 } 186 } 187 188 // runReadRoutine runs the client <- server read routine 189 func (c *Client) runReadRoutine(ctx context.Context) { 190 for { 191 select { 192 case <-ctx.Done(): 193 c.logger.Debug("read context finished") 194 195 return 196 default: 197 } 198 199 // Read the message from the active connection 200 _, data, err := c.conn.ReadMessage() 201 if err != nil { 202 if websocket.IsUnexpectedCloseError(err, websocket.CloseNormalClosure) { 203 c.logger.Error("failed to read response", "err", err) 204 205 // Server dropped the connection, stop the client 206 if err = c.closeWithCause( 207 fmt.Errorf("server closed connection, %w", err), 208 ); err != nil { 209 c.logger.Error("unable to gracefully close client", "err", err) 210 } 211 212 return 213 } 214 215 continue 216 } 217 218 var ( 219 responses types.RPCResponses 220 responseHash string 221 ) 222 223 // Try to unmarshal as a batch of responses first 224 if err := json.Unmarshal(data, &responses); err != nil { 225 // Try to unmarshal as a single response 226 var response types.RPCResponse 227 228 if err := json.Unmarshal(data, &response); err != nil { 229 c.logger.Error("failed to parse response", "err", err, "data", string(data)) 230 231 continue 232 } 233 234 // This is a single response, generate the unique ID 235 responseHash = generateIDHash(response.ID.String()) 236 responses = types.RPCResponses{response} 237 } else { 238 // This is a batch response, generate the unique ID 239 // from the combined IDs 240 ids := make([]string, 0, len(responses)) 241 242 for _, response := range responses { 243 ids = append(ids, response.ID.String()) 244 } 245 246 responseHash = generateIDHash(ids...) 247 } 248 249 // Grab the response channel 250 c.requestMapMux.Lock() 251 ch := c.requestMap[responseHash] 252 if ch == nil { 253 c.requestMapMux.Unlock() 254 c.logger.Error("response listener not set", "hash", responseHash, "responses", responses) 255 256 continue 257 } 258 259 // Clear the entry for this ID 260 delete(c.requestMap, responseHash) 261 c.requestMapMux.Unlock() 262 263 c.logger.Debug("received response", "hash", responseHash) 264 265 // Alert the listener of the response 266 select { 267 case ch <- responses: 268 default: 269 c.logger.Warn("response listener timed out", "hash", responseHash) 270 } 271 } 272 } 273 274 // Close closes the WS client 275 func (c *Client) Close() error { 276 return c.closeWithCause(nil) 277 } 278 279 // closeWithCause closes the client (and any open connection) 280 // with the given cause 281 func (c *Client) closeWithCause(err error) error { 282 c.cancelCauseFn(err) 283 284 return c.conn.Close() 285 }