github.com/amazechain/amc@v0.1.3/modules/rpc/jsonrpc/client.go (about)

     1  // Copyright 2022 The AmazeChain Authors
     2  // This file is part of the AmazeChain library.
     3  //
     4  // The AmazeChain library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The AmazeChain library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the AmazeChain library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package jsonrpc
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"errors"
    23  	"fmt"
    24  	"github.com/amazechain/amc/log"
    25  	"net/url"
    26  	"reflect"
    27  	"strconv"
    28  	"sync/atomic"
    29  	"time"
    30  )
    31  
    32  var (
    33  	ErrClientQuit                = errors.New("client is closed")
    34  	ErrNoResult                  = errors.New("no result in JSON-RPC response")
    35  	ErrSubscriptionQueueOverflow = errors.New("subscription queue overflow")
    36  	errClientReconnected         = errors.New("client reconnected")
    37  	errDead                      = errors.New("connection lost")
    38  )
    39  
    40  const (
    41  	defaultDialTimeout = 10 * time.Second
    42  	subscribeTimeout   = 5 * time.Second
    43  )
    44  
    45  const (
    46  	// Subscriptions are removed when the subscriber cannot keep up.
    47  	//
    48  	// This can be worked around by supplying a channel with sufficiently sized buffer,
    49  	// but this can be inconvenient and hard to explain in the docs. Another issue with
    50  	// buffered channels is that the buffer is static even though it might not be needed
    51  	// most of the time.
    52  	//
    53  	// The approach taken here is to maintain a per-subscription linked list buffer
    54  	// shrinks on demand. If the buffer reaches the size below, the subscription is
    55  	// dropped.
    56  	maxClientSubscriptionBuffer = 20000
    57  )
    58  
// Client is a JSON-RPC client bound to one transport. For non-HTTP transports,
// initClient starts a dispatch goroutine that coordinates the channels below;
// HTTP clients start no such goroutine.
type Client struct {
	idgen    func() ID        // for subscriptions; also handed to each connection's handler
	isHTTP   bool             // set when the transport is an *httpConn (see initClient)
	services *serviceRegistry // filled by RegisterName; passed to each connection's handler

	idCounter     uint32        // last request ID issued by nextID (updated atomically)
	reconnectFunc reconnectFunc // redials the transport; nil makes reconnect return errDead

	// writeConn is the codec used for outgoing messages. write sets it to nil
	// after a failed write, and reconnect installs a replacement.
	writeConn jsonWriter

	// Channels of the dispatch event loop.
	close       chan struct{}    // Close asks dispatch to stop
	closing     chan struct{}    // closed by dispatch as it begins shutting down
	didClose    chan struct{}    // closed by dispatch once shutdown is complete
	reconnected chan ServerCodec // reconnect hands a freshly dialed codec to dispatch
	readOp      chan readOp      // read goroutine delivers incoming messages
	readErr     chan error       // read goroutine delivers its terminating error
	reqInit     chan *requestOp  // send registers a request before writing it
	reqSent     chan error       // send reports the outcome of the write
	reqTimeout  chan *requestOp  // wait reports requests whose context expired
}

// reconnectFunc dials a transport and returns its codec. newClient uses it for
// the initial connection, and reconnect uses it to redial.
type reconnectFunc func(ctx context.Context) (ServerCodec, error)

// clientContextKey is the context key under which newClientConn stores the
// *Client; ClientFromContext reads it back.
type clientContextKey struct{}

// clientConn pairs a transport codec with the handler that dispatch feeds the
// messages read from that codec.
type clientConn struct {
	codec   ServerCodec
	handler *handler
}
    88  
    89  func (c *Client) newClientConn(conn ServerCodec) *clientConn {
    90  	ctx := context.WithValue(context.Background(), clientContextKey{}, c)
    91  	handler := newHandler(ctx, conn, c.idgen, c.services)
    92  	return &clientConn{conn, handler}
    93  }
    94  
    95  func (cc *clientConn) close(err error, inflightReq *requestOp) {
    96  	cc.handler.close(err, inflightReq)
    97  	cc.codec.close()
    98  }
    99  
// readOp carries the result of one readBatch call made by the read goroutine.
type readOp struct {
	msgs  []*jsonrpcMessage
	batch bool // true when msgs came from a batch; dispatch then calls handleBatch
}

// requestOp tracks an outgoing request until its response arrives.
type requestOp struct {
	ids  []json.RawMessage    // IDs of the request messages (empty for Notify)
	err  error                // returned by wait alongside the response; set outside this file
	resp chan *jsonrpcMessage // wait receives the response message from here
	sub  *ClientSubscription  // set only for Subscribe requests
}
   111  
   112  func (op *requestOp) wait(ctx context.Context, c *Client) (*jsonrpcMessage, error) {
   113  	select {
   114  	case <-ctx.Done():
   115  		if !c.isHTTP {
   116  			select {
   117  			case c.reqTimeout <- op:
   118  			case <-c.closing:
   119  			}
   120  		}
   121  		return nil, ctx.Err()
   122  	case resp := <-op.resp:
   123  		return resp, op.err
   124  	}
   125  }
   126  
   127  func Dial(rawurl string) (*Client, error) {
   128  	return DialContext(context.Background(), rawurl)
   129  }
   130  
   131  func DialContext(ctx context.Context, rawurl string) (*Client, error) {
   132  	u, err := url.Parse(rawurl)
   133  	if err != nil {
   134  		return nil, err
   135  	}
   136  	switch u.Scheme {
   137  	case "http", "https":
   138  		return DialHTTP(rawurl)
   139  	case "ws", "wss":
   140  		return DialWebsocket(ctx, rawurl, "")
   141  	case "":
   142  		return DialIPC(ctx, rawurl)
   143  	default:
   144  		return nil, fmt.Errorf("no known transport for URL scheme %q", u.Scheme)
   145  	}
   146  }
   147  func ClientFromContext(ctx context.Context) (*Client, bool) {
   148  	client, ok := ctx.Value(clientContextKey{}).(*Client)
   149  	return client, ok
   150  }
   151  
   152  func newClient(initctx context.Context, connect reconnectFunc) (*Client, error) {
   153  	conn, err := connect(initctx)
   154  	if err != nil {
   155  		return nil, err
   156  	}
   157  	c := initClient(conn, randomIDGenerator(), new(serviceRegistry))
   158  	c.reconnectFunc = connect
   159  	return c, nil
   160  }
   161  
   162  func initClient(conn ServerCodec, idgen func() ID, services *serviceRegistry) *Client {
   163  	_, isHTTP := conn.(*httpConn)
   164  	c := &Client{
   165  		isHTTP:      isHTTP,
   166  		idgen:       idgen,
   167  		services:    services,
   168  		writeConn:   conn,
   169  		close:       make(chan struct{}),
   170  		closing:     make(chan struct{}),
   171  		didClose:    make(chan struct{}),
   172  		reconnected: make(chan ServerCodec),
   173  		readOp:      make(chan readOp),
   174  		readErr:     make(chan error),
   175  		reqInit:     make(chan *requestOp),
   176  		reqSent:     make(chan error, 1),
   177  		reqTimeout:  make(chan *requestOp),
   178  	}
   179  	if !isHTTP {
   180  		go c.dispatch(conn)
   181  	}
   182  	return c
   183  }
   184  
   185  func (c *Client) RegisterName(name string, receiver interface{}) error {
   186  	return c.services.registerName(name, receiver)
   187  }
   188  
   189  func (c *Client) nextID() json.RawMessage {
   190  	id := atomic.AddUint32(&c.idCounter, 1)
   191  	return strconv.AppendUint(nil, uint64(id), 10)
   192  }
   193  
   194  func (c *Client) SupportedModules() (map[string]string, error) {
   195  	var result map[string]string
   196  	ctx, cancel := context.WithTimeout(context.Background(), subscribeTimeout)
   197  	defer cancel()
   198  	err := c.CallContext(ctx, &result, "rpc_modules")
   199  	return result, err
   200  }
   201  
   202  func (c *Client) Close() {
   203  	if c.isHTTP {
   204  		return
   205  	}
   206  	select {
   207  	case c.close <- struct{}{}:
   208  		<-c.didClose
   209  	case <-c.didClose:
   210  	}
   211  }
   212  
   213  func (c *Client) SetHeader(key, value string) {
   214  	if !c.isHTTP {
   215  		return
   216  	}
   217  	conn := c.writeConn.(*httpConn)
   218  	conn.mu.Lock()
   219  	conn.headers.Set(key, value)
   220  	conn.mu.Unlock()
   221  }
   222  
   223  func (c *Client) Call(result interface{}, method string, args ...interface{}) error {
   224  	ctx := context.Background()
   225  	return c.CallContext(ctx, result, method, args...)
   226  }
   227  
   228  func (c *Client) CallContext(ctx context.Context, result interface{}, method string, args ...interface{}) error {
   229  	if result != nil && reflect.TypeOf(result).Kind() != reflect.Ptr {
   230  		return fmt.Errorf("call result parameter must be pointer or nil interface: %v", result)
   231  	}
   232  	msg, err := c.newMessage(method, args...)
   233  	if err != nil {
   234  		return err
   235  	}
   236  	op := &requestOp{ids: []json.RawMessage{msg.ID}, resp: make(chan *jsonrpcMessage, 1)}
   237  
   238  	if c.isHTTP {
   239  		err = c.sendHTTP(ctx, op, msg)
   240  	} else {
   241  		err = c.send(ctx, op, msg)
   242  	}
   243  	if err != nil {
   244  		return err
   245  	}
   246  
   247  	switch resp, err := op.wait(ctx, c); {
   248  	case err != nil:
   249  		return err
   250  	case resp.Error != nil:
   251  		return resp.Error
   252  	case len(resp.Result) == 0:
   253  		return ErrNoResult
   254  	default:
   255  		return json.Unmarshal(resp.Result, &result)
   256  	}
   257  }
   258  
   259  func (c *Client) Notify(ctx context.Context, method string, args ...interface{}) error {
   260  	op := new(requestOp)
   261  	msg, err := c.newMessage(method, args...)
   262  	if err != nil {
   263  		return err
   264  	}
   265  	msg.ID = nil
   266  
   267  	if c.isHTTP {
   268  		return c.sendHTTP(ctx, op, msg)
   269  	}
   270  	return c.send(ctx, op, msg)
   271  }
   272  
   273  func (c *Client) newMessage(method string, paramsIn ...interface{}) (*jsonrpcMessage, error) {
   274  	msg := &jsonrpcMessage{Version: vsn, ID: c.nextID(), Method: method}
   275  	if paramsIn != nil { // prevent sending "params":null
   276  		var err error
   277  		if msg.Params, err = json.Marshal(paramsIn); err != nil {
   278  			return nil, err
   279  		}
   280  	}
   281  	return msg, nil
   282  }
   283  
   284  func (c *Client) send(ctx context.Context, op *requestOp, msg interface{}) error {
   285  	select {
   286  	case c.reqInit <- op:
   287  		err := c.write(ctx, msg, false)
   288  		c.reqSent <- err
   289  		return err
   290  	case <-ctx.Done():
   291  		return ctx.Err()
   292  	case <-c.closing:
   293  		return ErrClientQuit
   294  	}
   295  }
   296  
   297  func (c *Client) write(ctx context.Context, msg interface{}, retry bool) error {
   298  	if c.writeConn == nil {
   299  		if err := c.reconnect(ctx); err != nil {
   300  			return err
   301  		}
   302  	}
   303  	err := c.writeConn.writeJSON(ctx, msg)
   304  	if err != nil {
   305  		c.writeConn = nil
   306  		if !retry {
   307  			return c.write(ctx, msg, true)
   308  		}
   309  	}
   310  	return err
   311  }
   312  
   313  func (c *Client) reconnect(ctx context.Context) error {
   314  	if c.reconnectFunc == nil {
   315  		return errDead
   316  	}
   317  
   318  	if _, ok := ctx.Deadline(); !ok {
   319  		var cancel func()
   320  		ctx, cancel = context.WithTimeout(ctx, defaultDialTimeout)
   321  		defer cancel()
   322  	}
   323  	newconn, err := c.reconnectFunc(ctx)
   324  	if err != nil {
   325  		log.Debug("RPC client reconnect failed", "err", err)
   326  		return err
   327  	}
   328  	select {
   329  	case c.reconnected <- newconn:
   330  		c.writeConn = newconn
   331  		return nil
   332  	case <-c.didClose:
   333  		newconn.close()
   334  		return ErrClientQuit
   335  	}
   336  }
   337  
// dispatch is the event loop of a non-HTTP client; initClient starts it in its
// own goroutine. It owns the current clientConn and:
//   - passes messages from the read goroutine to that connection's handler;
//   - registers and removes in-flight requests;
//   - switches to a new connection when reconnect delivers one.
//
// It returns when Close signals c.close.
//
// Only one request can be in the middle of being written at a time. After
// accepting an op on reqInit, the loop stops listening on reqInit until send
// reports the write result on reqSent.
func (c *Client) dispatch(codec ServerCodec) {
	var (
		// lastOp is the request currently being written by send, if any.
		lastOp      *requestOp
		// reqInitLock is c.reqInit while a new request may start. It is nil
		// while one is being written, and a nil channel is never selected.
		reqInitLock = c.reqInit
		conn        = c.newClientConn(codec)
		// reading reports whether a read goroutine is running for conn.
		reading     = true
	)
	defer func() {
		// Unblock send and wait, which both select on c.closing.
		close(c.closing)
		if reading {
			// Close the connection and consume the reader's remaining output,
			// so the read goroutine is not left blocked on an unbuffered send.
			conn.close(ErrClientQuit, nil)
			c.drainRead()
		}
		close(c.didClose)
	}()

	// Start reading from the initial connection.
	go c.read(codec)

	for {
		select {
		case <-c.close:
			// Close was called; the deferred cleanup runs on return.
			return

		case op := <-c.readOp:
			// Incoming message(s) from the read goroutine.
			if op.batch {
				conn.handler.handleBatch(op.msgs)
			} else {
				conn.handler.handleMsg(op.msgs[0])
			}

		case err := <-c.readErr:
			// The read goroutine has exited. Close the connection, passing
			// the request being written (if any) as the in-flight request.
			log.Debug("RPC connection read error", "err", err)
			conn.close(err, lastOp)
			reading = false

		case newcodec := <-c.reconnected:
			// reconnect dialed a new connection. In this file that only
			// happens from write, i.e. while lastOp is being written.
			log.Debug("RPC client reconnected", "reading", reading, "conn", newcodec.remoteAddr())
			if reading {
				// Retire the old connection and its reader first.
				conn.close(errClientReconnected, lastOp)
				c.drainRead()
			}
			go c.read(newcodec)
			reading = true
			conn = c.newClientConn(newcodec)
			// Re-register the pending request with the new handler.
			conn.handler.addRequestOp(lastOp)

		case op := <-reqInitLock:
			// send is about to write op. Stop accepting new requests until
			// it reports the outcome on reqSent.
			reqInitLock = nil
			lastOp = op
			conn.handler.addRequestOp(op)

		case err := <-c.reqSent:
			// The write finished. If it failed, no response will come, so
			// forget the op. Then accept new requests again.
			if err != nil {
				conn.handler.removeRequestOp(lastOp)
			}
			reqInitLock = c.reqInit
			lastOp = nil

		case op := <-c.reqTimeout:
			// A caller stopped waiting (its context ended); forget the op.
			conn.handler.removeRequestOp(op)
		}
	}
}
   401  
   402  func (c *Client) drainRead() {
   403  	for {
   404  		select {
   405  		case <-c.readOp:
   406  		case <-c.readErr:
   407  			return
   408  		}
   409  	}
   410  }
   411  
   412  func (c *Client) read(codec ServerCodec) {
   413  	for {
   414  		msgs, batch, err := codec.readBatch()
   415  		if _, ok := err.(*json.SyntaxError); ok {
   416  			codec.writeJSON(context.Background(), errorMessage(&parseError{err.Error()}))
   417  		}
   418  		if err != nil {
   419  			c.readErr <- err
   420  			return
   421  		}
   422  		c.readOp <- readOp{msgs, batch}
   423  	}
   424  }
   425  
   426  func (c *Client) Subscribe(ctx context.Context, namespace string, channel interface{}, args ...interface{}) (*ClientSubscription, error) {
   427  	// Check type of channel first.
   428  	chanVal := reflect.ValueOf(channel)
   429  	if chanVal.Kind() != reflect.Chan || chanVal.Type().ChanDir()&reflect.SendDir == 0 {
   430  		panic(fmt.Sprintf("channel argument of Subscribe has type %T, need writable channel", channel))
   431  	}
   432  	if chanVal.IsNil() {
   433  		panic("channel given to Subscribe must not be nil")
   434  	}
   435  	if c.isHTTP {
   436  		return nil, ErrNotificationsUnsupported
   437  	}
   438  
   439  	msg, err := c.newMessage(namespace+subscribeMethodSuffix, args...)
   440  	if err != nil {
   441  		return nil, err
   442  	}
   443  	op := &requestOp{
   444  		ids:  []json.RawMessage{msg.ID},
   445  		resp: make(chan *jsonrpcMessage),
   446  		sub:  newClientSubscription(c, namespace, chanVal),
   447  	}
   448  
   449  	// Send the subscription request.
   450  	// The arrival and validity of the response is signaled on sub.quit.
   451  	if err := c.send(ctx, op, msg); err != nil {
   452  		return nil, err
   453  	}
   454  	if _, err := op.wait(ctx, c); err != nil {
   455  		return nil, err
   456  	}
   457  	return op.sub, nil
   458  }