github.com/gagliardetto/solana-go@v1.11.0/rpc/ws/client.go (about) 1 // Copyright 2021 github.com/gagliardetto 2 // This file has been modified by github.com/gagliardetto 3 // 4 // Copyright 2020 dfuse Platform Inc. 5 // 6 // Licensed under the Apache License, Version 2.0 (the "License"); 7 // you may not use this file except in compliance with the License. 8 // You may obtain a copy of the License at 9 // 10 // http://www.apache.org/licenses/LICENSE-2.0 11 // 12 // Unless required by applicable law or agreed to in writing, software 13 // distributed under the License is distributed on an "AS IS" BASIS, 14 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 15 // See the License for the specific language governing permissions and 16 // limitations under the License. 17 18 package ws 19 20 import ( 21 "context" 22 "fmt" 23 "io" 24 "net/http" 25 "strconv" 26 "sync" 27 "time" 28 29 "github.com/buger/jsonparser" 30 "github.com/gorilla/rpc/v2/json2" 31 "github.com/gorilla/websocket" 32 "go.uber.org/zap" 33 ) 34 35 type result interface{} 36 37 type Client struct { 38 rpcURL string 39 conn *websocket.Conn 40 connCtx context.Context 41 connCtxCancel context.CancelFunc 42 lock sync.RWMutex 43 subscriptionByRequestID map[uint64]*Subscription 44 subscriptionByWSSubID map[uint64]*Subscription 45 reconnectOnErr bool 46 } 47 48 const ( 49 // Time allowed to write a message to the peer. 50 writeWait = 10 * time.Second 51 // Time allowed to read the next pong message from the peer. 52 pongWait = 60 * time.Second 53 // Send pings to peer with this period. Must be less than pongWait. 54 pingPeriod = (pongWait * 9) / 10 55 ) 56 57 // Connect creates a new websocket client connecting to the provided endpoint. 58 func Connect(ctx context.Context, rpcEndpoint string) (c *Client, err error) { 59 return ConnectWithOptions(ctx, rpcEndpoint, nil) 60 } 61 62 // ConnectWithOptions creates a new websocket client connecting to the provided 63 // endpoint with a http header if available The http header can be helpful to 64 // pass basic authentication params as prescribed 65 // ref https://github.com/gorilla/websocket/issues/209 66 func ConnectWithOptions(ctx context.Context, rpcEndpoint string, opt *Options) (c *Client, err error) { 67 c = &Client{ 68 rpcURL: rpcEndpoint, 69 subscriptionByRequestID: map[uint64]*Subscription{}, 70 subscriptionByWSSubID: map[uint64]*Subscription{}, 71 } 72 73 dialer := &websocket.Dialer{ 74 Proxy: http.ProxyFromEnvironment, 75 HandshakeTimeout: DefaultHandshakeTimeout, 76 EnableCompression: true, 77 } 78 79 if opt != nil && opt.HandshakeTimeout > 0 { 80 dialer.HandshakeTimeout = opt.HandshakeTimeout 81 } 82 83 var httpHeader http.Header = nil 84 if opt != nil && opt.HttpHeader != nil && len(opt.HttpHeader) > 0 { 85 httpHeader = opt.HttpHeader 86 } 87 var resp *http.Response 88 c.conn, resp, err = dialer.DialContext(ctx, rpcEndpoint, httpHeader) 89 if err != nil { 90 if resp != nil { 91 body, _ := io.ReadAll(resp.Body) 92 err = fmt.Errorf("new ws client: dial: %w, status: %s, body: %q", err, resp.Status, string(body)) 93 } else { 94 err = fmt.Errorf("new ws client: dial: %w", err) 95 } 96 return nil, err 97 } 98 99 c.connCtx, c.connCtxCancel = context.WithCancel(context.Background()) 100 go func() { 101 c.conn.SetReadDeadline(time.Now().Add(pongWait)) 102 c.conn.SetPongHandler(func(string) error { c.conn.SetReadDeadline(time.Now().Add(pongWait)); return nil }) 103 ticker := time.NewTicker(pingPeriod) 104 for { 105 select { 106 case <-c.connCtx.Done(): 107 return 108 case <-ticker.C: 109 c.sendPing() 110 } 111 } 112 }() 113 go c.receiveMessages() 114 return c, nil 115 } 116 117 func (c *Client) sendPing() { 118 c.lock.Lock() 119 defer c.lock.Unlock() 120 121 c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 122 if err := c.conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { 123 return 124 } 125 } 126 127 func (c *Client) Close() { 128 c.lock.Lock() 129 defer c.lock.Unlock() 130 c.connCtxCancel() 131 c.conn.Close() 132 } 133 134 func (c *Client) receiveMessages() { 135 for { 136 select { 137 case <-c.connCtx.Done(): 138 return 139 default: 140 _, message, err := c.conn.ReadMessage() 141 if err != nil { 142 c.closeAllSubscription(err) 143 return 144 } 145 c.handleMessage(message) 146 } 147 } 148 } 149 150 // GetUint64 returns the value retrieved by `Get`, cast to a uint64 if possible. 151 // If key data type do not match, it will return an error. 152 func getUint64(data []byte, keys ...string) (val uint64, err error) { 153 v, t, _, e := jsonparser.Get(data, keys...) 154 if e != nil { 155 return 0, e 156 } 157 if t != jsonparser.Number { 158 return 0, fmt.Errorf("Value is not a number: %s", string(v)) 159 } 160 return strconv.ParseUint(string(v), 10, 64) 161 } 162 163 func getUint64WithOk(data []byte, path ...string) (uint64, bool) { 164 val, err := getUint64(data, path...) 165 if err == nil { 166 return val, true 167 } 168 return 0, false 169 } 170 171 func (c *Client) handleMessage(message []byte) { 172 // when receiving message with id. the result will be a subscription number. 173 // that number will be associated to all future message destine to this request 174 175 requestID, ok := getUint64WithOk(message, "id") 176 if ok { 177 subID, _ := getUint64WithOk(message, "result") 178 c.handleNewSubscriptionMessage(requestID, subID) 179 return 180 } 181 182 subID, _ := getUint64WithOk(message, "params", "subscription") 183 c.handleSubscriptionMessage(subID, message) 184 } 185 186 func (c *Client) handleNewSubscriptionMessage(requestID, subID uint64) { 187 c.lock.Lock() 188 defer c.lock.Unlock() 189 190 if traceEnabled { 191 zlog.Debug("received new subscription message", 192 zap.Uint64("message_id", requestID), 193 zap.Uint64("subscription_id", subID), 194 ) 195 } 196 197 callBack, found := c.subscriptionByRequestID[requestID] 198 if !found { 199 zlog.Error("cannot find websocket message handler for a new stream.... this should not happen", 200 zap.Uint64("request_id", requestID), 201 zap.Uint64("subscription_id", subID), 202 ) 203 return 204 } 205 callBack.subID = subID 206 c.subscriptionByWSSubID[subID] = callBack 207 208 zlog.Debug("registered ws subscription", 209 zap.Uint64("subscription_id", subID), 210 zap.Uint64("request_id", requestID), 211 zap.Int("subscription_count", len(c.subscriptionByWSSubID)), 212 ) 213 return 214 } 215 216 func (c *Client) handleSubscriptionMessage(subID uint64, message []byte) { 217 if traceEnabled { 218 zlog.Debug("received subscription message", 219 zap.Uint64("subscription_id", subID), 220 ) 221 } 222 223 c.lock.RLock() 224 sub, found := c.subscriptionByWSSubID[subID] 225 c.lock.RUnlock() 226 if !found { 227 zlog.Warn("unable to find subscription for ws message", zap.Uint64("subscription_id", subID)) 228 return 229 } 230 231 // Decode the message using the subscription-provided decoderFunc. 232 result, err := sub.decoderFunc(message) 233 if err != nil { 234 fmt.Println("*****************************") 235 c.closeSubscription(sub.req.ID, fmt.Errorf("unable to decode client response: %w", err)) 236 return 237 } 238 239 // this cannot be blocking or else 240 // we will no read any other message 241 if len(sub.stream) >= cap(sub.stream) { 242 zlog.Warn("closing ws client subscription... not consuming fast en ought", 243 zap.Uint64("request_id", sub.req.ID), 244 ) 245 c.closeSubscription(sub.req.ID, fmt.Errorf("reached channel max capacity %d", len(sub.stream))) 246 return 247 } 248 249 if !sub.closed { 250 sub.stream <- result 251 } 252 return 253 } 254 255 func (c *Client) closeAllSubscription(err error) { 256 c.lock.Lock() 257 defer c.lock.Unlock() 258 259 for _, sub := range c.subscriptionByRequestID { 260 sub.err <- err 261 } 262 263 c.subscriptionByRequestID = map[uint64]*Subscription{} 264 c.subscriptionByWSSubID = map[uint64]*Subscription{} 265 } 266 267 func (c *Client) closeSubscription(reqID uint64, err error) { 268 c.lock.Lock() 269 defer c.lock.Unlock() 270 271 sub, found := c.subscriptionByRequestID[reqID] 272 if !found { 273 return 274 } 275 276 sub.err <- err 277 278 err = c.unsubscribe(sub.subID, sub.unsubscribeMethod) 279 if err != nil { 280 zlog.Warn("unable to send rpc unsubscribe call", 281 zap.Error(err), 282 ) 283 } 284 285 delete(c.subscriptionByRequestID, sub.req.ID) 286 delete(c.subscriptionByWSSubID, sub.subID) 287 } 288 289 func (c *Client) unsubscribe(subID uint64, method string) error { 290 req := newRequest([]interface{}{subID}, method, nil) 291 data, err := req.encode() 292 if err != nil { 293 return fmt.Errorf("unable to encode unsubscription message for subID %d and method %s", subID, method) 294 } 295 296 c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 297 err = c.conn.WriteMessage(websocket.TextMessage, data) 298 if err != nil { 299 return fmt.Errorf("unable to send unsubscription message for subID %d and method %s", subID, method) 300 } 301 return nil 302 } 303 304 func (c *Client) subscribe( 305 params []interface{}, 306 conf map[string]interface{}, 307 subscriptionMethod string, 308 unsubscribeMethod string, 309 decoderFunc decoderFunc, 310 ) (*Subscription, error) { 311 c.lock.Lock() 312 defer c.lock.Unlock() 313 314 req := newRequest(params, subscriptionMethod, conf) 315 data, err := req.encode() 316 if err != nil { 317 return nil, fmt.Errorf("subscribe: unable to encode subsciption request: %w", err) 318 } 319 320 sub := newSubscription( 321 req, 322 func(err error) { 323 c.closeSubscription(req.ID, err) 324 }, 325 unsubscribeMethod, 326 decoderFunc, 327 ) 328 329 c.subscriptionByRequestID[req.ID] = sub 330 zlog.Info("added new subscription to websocket client", zap.Int("count", len(c.subscriptionByRequestID))) 331 332 zlog.Debug("writing data to conn", zap.String("data", string(data))) 333 c.conn.SetWriteDeadline(time.Now().Add(writeWait)) 334 err = c.conn.WriteMessage(websocket.TextMessage, data) 335 if err != nil { 336 return nil, fmt.Errorf("unable to write request: %w", err) 337 } 338 339 return sub, nil 340 } 341 342 func decodeResponseFromReader(r io.Reader, reply interface{}) (err error) { 343 var c *response 344 if err := json.NewDecoder(r).Decode(&c); err != nil { 345 return err 346 } 347 348 if c.Error != nil { 349 jsonErr := &json2.Error{} 350 if err := json.Unmarshal(*c.Error, jsonErr); err != nil { 351 return &json2.Error{ 352 Code: json2.E_SERVER, 353 Message: string(*c.Error), 354 } 355 } 356 return jsonErr 357 } 358 359 if c.Params == nil { 360 return json2.ErrNullResult 361 } 362 363 return json.Unmarshal(*c.Params.Result, &reply) 364 } 365 366 func decodeResponseFromMessage(r []byte, reply interface{}) (err error) { 367 var c *response 368 if err := json.Unmarshal(r, &c); err != nil { 369 return err 370 } 371 372 if c.Error != nil { 373 jsonErr := &json2.Error{} 374 if err := json.Unmarshal(*c.Error, jsonErr); err != nil { 375 return &json2.Error{ 376 Code: json2.E_SERVER, 377 Message: string(*c.Error), 378 } 379 } 380 return jsonErr 381 } 382 383 if c.Params == nil { 384 return json2.ErrNullResult 385 } 386 387 return json.Unmarshal(*c.Params.Result, &reply) 388 }