github.com/amazechain/amc@v0.1.3/modules/rpc/jsonrpc/client.go (about) 1 // Copyright 2022 The AmazeChain Authors 2 // This file is part of the AmazeChain library. 3 // 4 // The AmazeChain library is free software: you can redistribute it and/or modify 5 // it under the terms of the GNU Lesser General Public License as published by 6 // the Free Software Foundation, either version 3 of the License, or 7 // (at your option) any later version. 8 // 9 // The AmazeChain library is distributed in the hope that it will be useful, 10 // but WITHOUT ANY WARRANTY; without even the implied warranty of 11 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 12 // GNU Lesser General Public License for more details. 13 // 14 // You should have received a copy of the GNU Lesser General Public License 15 // along with the AmazeChain library. If not, see <http://www.gnu.org/licenses/>. 16 17 package jsonrpc 18 19 import ( 20 "context" 21 "encoding/json" 22 "errors" 23 "fmt" 24 "github.com/amazechain/amc/log" 25 "net/url" 26 "reflect" 27 "strconv" 28 "sync/atomic" 29 "time" 30 ) 31 32 var ( 33 ErrClientQuit = errors.New("client is closed") 34 ErrNoResult = errors.New("no result in JSON-RPC response") 35 ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow") 36 errClientReconnected = errors.New("client reconnected") 37 errDead = errors.New("connection lost") 38 ) 39 40 const ( 41 defaultDialTimeout = 10 * time.Second 42 subscribeTimeout = 5 * time.Second 43 ) 44 45 const ( 46 // Subscriptions are removed when the subscriber cannot keep up. 47 // 48 // This can be worked around by supplying a channel with sufficiently sized buffer, 49 // but this can be inconvenient and hard to explain in the docs. Another issue with 50 // buffered channels is that the buffer is static even though it might not be needed 51 // most of the time. 52 // 53 // The approach taken here is to maintain a per-subscription linked list buffer 54 // shrinks on demand. If the buffer reaches the size below, the subscription is 55 // dropped. 56 maxClientSubscriptionBuffer = 20000 57 ) 58 59 type Client struct { 60 idgen func() ID // for subscriptions 61 isHTTP bool 62 services *serviceRegistry 63 64 idCounter uint32 65 reconnectFunc reconnectFunc 66 67 writeConn jsonWriter 68 69 close chan struct{} 70 closing chan struct{} 71 didClose chan struct{} 72 reconnected chan ServerCodec 73 readOp chan readOp 74 readErr chan error 75 reqInit chan *requestOp 76 reqSent chan error 77 reqTimeout chan *requestOp 78 } 79 80 type reconnectFunc func(ctx context.Context) (ServerCodec, error) 81 82 type clientContextKey struct{} 83 84 type clientConn struct { 85 codec ServerCodec 86 handler *handler 87 } 88 89 func (c *Client) newClientConn(conn ServerCodec) *clientConn { 90 ctx := context.WithValue(context.Background(), clientContextKey{}, c) 91 handler := newHandler(ctx, conn, c.idgen, c.services) 92 return &clientConn{conn, handler} 93 } 94 95 func (cc *clientConn) close(err error, inflightReq *requestOp) { 96 cc.handler.close(err, inflightReq) 97 cc.codec.close() 98 } 99 100 type readOp struct { 101 msgs []*jsonrpcMessage 102 batch bool 103 } 104 105 type requestOp struct { 106 ids []json.RawMessage 107 err error 108 resp chan *jsonrpcMessage 109 sub *ClientSubscription 110 } 111 112 func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) { 113 select { 114 case <-ctx.Done(): 115 if !c.isHTTP { 116 select { 117 case c.reqTimeout <- op: 118 case <-c.closing: 119 } 120 } 121 return nil, ctx.Err() 122 case resp := <-op.resp: 123 return resp, op.err 124 } 125 } 126 127 func Dial(rawurl string) (*Client, error) { 128 return DialContext(context.Background(), rawurl) 129 } 130 131 func DialContext(ctx context.Context, rawurl string) (*Client, error) { 132 u, err := url.Parse(rawurl) 133 if err != nil { 134 return nil, err 135 } 136 switch u.Scheme { 137 case "http", "https": 138 return DialHTTP(rawurl) 139 case "ws", "wss": 140 return DialWebsocket(ctx, rawurl, "") 141 case "": 142 return DialIPC(ctx, rawurl) 143 default: 144 return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme) 145 } 146 } 147 func ClientFromContext(ctx context.Context) (*Client, bool) { 148 client, ok := ctx.Value(clientContextKey{}).(*Client) 149 return client, ok 150 } 151 152 func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) { 153 conn, err := connect(initctx) 154 if err != nil { 155 return nil, err 156 } 157 c := initClient(conn, randomIDGenerator(), new(serviceRegistry)) 158 c.reconnectFunc = connect 159 return c, nil 160 } 161 162 func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client { 163 _, isHTTP := conn.(*httpConn) 164 c := &Client{ 165 isHTTP: isHTTP, 166 idgen: idgen, 167 services: services, 168 writeConn: conn, 169 close: make(chan struct{}), 170 closing: make(chan struct{}), 171 didClose: make(chan struct{}), 172 reconnected: make(chan ServerCodec), 173 readOp: make(chan readOp), 174 readErr: make(chan error), 175 reqInit: make(chan *requestOp), 176 reqSent: make(chan error, 1), 177 reqTimeout: make(chan *requestOp), 178 } 179 if !isHTTP { 180 go c.dispatch(conn) 181 } 182 return c 183 } 184 185 func (c *Client) RegisterName(name string, receiver interface{}) error { 186 return c.services.registerName(name, receiver) 187 } 188 189 func (c *Client) nextID() json.RawMessage { 190 id := atomic.AddUint32(&c.idCounter, 1) 191 return strconv.AppendUint(nil, uint64(id), 10) 192 } 193 194 func (c *Client) SupportedModules() (map[string]string, error) { 195 var result map[string]string 196 ctx, cancel := context.WithTimeout(context.Background(), subscribeTimeout) 197 defer cancel() 198 err := c.CallContext(ctx, &result, "rpc_modules") 199 return result, err 200 } 201 202 func (c *Client) Close() { 203 if c.isHTTP { 204 return 205 } 206 select { 207 case c.close <- struct{}{}: 208 <-c.didClose 209 case <-c.didClose: 210 } 211 } 212 213 func (c *Client) SetHeader(key, value string) { 214 if !c.isHTTP { 215 return 216 } 217 conn := c.writeConn.(*httpConn) 218 conn.mu.Lock() 219 conn.headers.Set(key, value) 220 conn.mu.Unlock() 221 } 222 223 func (c *Client) Call(result interface{}, method string, args ...interface{}) error { 224 ctx := context.Background() 225 return c.CallContext(ctx, result, method, args...) 226 } 227 228 func (c *Client) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error { 229 if result != nil && reflect.TypeOf(result).Kind() != reflect.Ptr { 230 return fmt.Errorf("call result parameter must be pointer or nil interface: %v", result) 231 } 232 msg, err := c.newMessage(method, args...) 233 if err != nil { 234 return err 235 } 236 op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)} 237 238 if c.isHTTP { 239 err = c.sendHTTP(ctx, op, msg) 240 } else { 241 err = c.send(ctx, op, msg) 242 } 243 if err != nil { 244 return err 245 } 246 247 switch resp, err := op.wait(ctx, c); { 248 case err != nil: 249 return err 250 case resp.Error != nil: 251 return resp.Error 252 case len(resp.Result) == 0: 253 return ErrNoResult 254 default: 255 return json.Unmarshal(resp.Result, &result) 256 } 257 } 258 259 func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) error { 260 op := new(requestOp) 261 msg, err := c.newMessage(method, args...) 262 if err != nil { 263 return err 264 } 265 msg.ID = nil 266 267 if c.isHTTP { 268 return c.sendHTTP(ctx, op, msg) 269 } 270 return c.send(ctx, op, msg) 271 } 272 273 func (c *Client) newMessage(method string, paramsIn ...interface{}) (*jsonrpcMessage, error) { 274 msg := &jsonrpcMessage{Version: vsn, ID: c.nextID(), Method: method} 275 if paramsIn != nil { // prevent sending "params":null 276 var err error 277 if msg.Params, err = json.Marshal(paramsIn); err != nil { 278 return nil, err 279 } 280 } 281 return msg, nil 282 } 283 284 func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error { 285 select { 286 case c.reqInit <- op: 287 err := c.write(ctx, msg, false) 288 c.reqSent <- err 289 return err 290 case <-ctx.Done(): 291 return ctx.Err() 292 case <-c.closing: 293 return ErrClientQuit 294 } 295 } 296 297 func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error { 298 if c.writeConn == nil { 299 if err := c.reconnect(ctx); err != nil { 300 return err 301 } 302 } 303 err := c.writeConn.writeJSON(ctx, msg) 304 if err != nil { 305 c.writeConn = nil 306 if !retry { 307 return c.write(ctx, msg, true) 308 } 309 } 310 return err 311 } 312 313 func (c *Client) reconnect(ctx context.Context) error { 314 if c.reconnectFunc == nil { 315 return errDead 316 } 317 318 if _, ok := ctx.Deadline(); !ok { 319 var cancel func() 320 ctx, cancel = context.WithTimeout(ctx, defaultDialTimeout) 321 defer cancel() 322 } 323 newconn, err := c.reconnectFunc(ctx) 324 if err != nil { 325 log.Debug("RPC client reconnect failed", "err", err) 326 return err 327 } 328 select { 329 case c.reconnected <- newconn: 330 c.writeConn = newconn 331 return nil 332 case <-c.didClose: 333 newconn.close() 334 return ErrClientQuit 335 } 336 } 337 338 func (c *Client) dispatch(codec ServerCodec) { 339 var ( 340 lastOp *requestOp 341 reqInitLock = c.reqInit 342 conn = c.newClientConn(codec) 343 reading = true 344 ) 345 defer func() { 346 close(c.closing) 347 if reading { 348 conn.close(ErrClientQuit, nil) 349 c.drainRead() 350 } 351 close(c.didClose) 352 }() 353 354 go c.read(codec) 355 356 for { 357 select { 358 case <-c.close: 359 return 360 361 case op := <-c.readOp: 362 if op.batch { 363 conn.handler.handleBatch(op.msgs) 364 } else { 365 conn.handler.handleMsg(op.msgs[0]) 366 } 367 368 case err := <-c.readErr: 369 log.Debug("RPC connection read error", "err", err) 370 conn.close(err, lastOp) 371 reading = false 372 373 case newcodec := <-c.reconnected: 374 log.Debug("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr()) 375 if reading { 376 conn.close(errClientReconnected, lastOp) 377 c.drainRead() 378 } 379 go c.read(newcodec) 380 reading = true 381 conn = c.newClientConn(newcodec) 382 conn.handler.addRequestOp(lastOp) 383 384 case op := <-reqInitLock: 385 reqInitLock = nil 386 lastOp = op 387 conn.handler.addRequestOp(op) 388 389 case err := <-c.reqSent: 390 if err != nil { 391 conn.handler.removeRequestOp(lastOp) 392 } 393 reqInitLock = c.reqInit 394 lastOp = nil 395 396 case op := <-c.reqTimeout: 397 conn.handler.removeRequestOp(op) 398 } 399 } 400 } 401 402 func (c *Client) drainRead() { 403 for { 404 select { 405 case <-c.readOp: 406 case <-c.readErr: 407 return 408 } 409 } 410 } 411 412 func (c *Client) read(codec ServerCodec) { 413 for { 414 msgs, batch, err := codec.readBatch() 415 if _, ok := err.(*json.SyntaxError); ok { 416 codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()})) 417 } 418 if err != nil { 419 c.readErr <- err 420 return 421 } 422 c.readOp <- readOp{msgs, batch} 423 } 424 } 425 426 func (c *Client) Subscribe(ctx context.Context, namespace string, channel interface{}, args ...interface{}) (*ClientSubscription, error) { 427 // Check type of channel first. 428 chanVal := reflect.ValueOf(channel) 429 if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 { 430 panic(fmt.Sprintf("channel argument of Subscribe has type %T, need writable channel", channel)) 431 } 432 if chanVal.IsNil() { 433 panic("channel given to Subscribe must not be nil") 434 } 435 if c.isHTTP { 436 return nil, ErrNotificationsUnsupported 437 } 438 439 msg, err := c.newMessage(namespace+subscribeMethodSuffix, args...) 440 if err != nil { 441 return nil, err 442 } 443 op := &requestOp{ 444 ids: []json.RawMessage{msg.ID}, 445 resp: make(chan *jsonrpcMessage), 446 sub: newClientSubscription(c, namespace, chanVal), 447 } 448 449 // Send the subscription request. 450 // The arrival and validity of the response is signaled on sub.quit. 451 if err := c.send(ctx, op, msg); err != nil { 452 return nil, err 453 } 454 if _, err := op.wait(ctx, c); err != nil { 455 return nil, err 456 } 457 return op.sub, nil 458 }