github.com/klaytn/klaytn@v1.10.2/networks/rpc/handler.go (about)

     1  // Modifications Copyright 2022 The klaytn Authors
     2  // Copyright 2022 The go-ethereum Authors
     3  // This file is part of the go-ethereum library.
     4  //
     5  // The go-ethereum library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Lesser General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // The go-ethereum library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    13  // GNU Lesser General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public License
    16  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    17  //
    18  // This file is derived from rpc/handler.go (2022/08/04).
    19  // Modified and improved for the klaytn development.
    20  
    21  package rpc
    22  
    23  import (
    24  	"context"
    25  	"encoding/json"
    26  	"fmt"
    27  	"reflect"
    28  	"strconv"
    29  	"strings"
    30  	"sync"
    31  	"sync/atomic"
    32  	"time"
    33  
    34  	"github.com/klaytn/klaytn/log"
    35  )
    36  
    37  // handler handles JSON-RPC messages. There is one handler per connection. Note that
    38  // handler is not safe for concurrent use. Message handling never blocks indefinitely
    39  // because RPCs are processed on background goroutines launched by handler.
    40  //
    41  // The entry points for incoming messages are:
    42  //
    43  //	h.handleMsg(message)
    44  //	h.handleBatch(message)
    45  //
    46  // Outgoing calls use the requestOp struct. Register the request before sending it
    47  // on the connection:
    48  //
    49  //	op := &requestOp{ids: ...}
    50  //	h.addRequestOp(op)
    51  //
    52  // Now send the request, then wait for the reply to be delivered through handleMsg:
    53  //
    54  //	if err := op.wait(...); err != nil {
    55  //	    h.removeRequestOp(op) // timeout, etc.
    56  //	}
    57  type handler struct {
    58  	reg            *serviceRegistry
    59  	unsubscribeCb  *callback
    60  	idgen          func() ID                      // subscription ID generator
    61  	respWait       map[string]*requestOp          // active client requests
    62  	clientSubs     map[string]*ClientSubscription // active client subscriptions
    63  	callWG         sync.WaitGroup                 // pending call goroutines
    64  	rootCtx        context.Context                // canceled by close()
    65  	cancelRoot     func()                         // cancel function for rootCtx
    66  	conn           jsonWriter                     // where responses will be sent
    67  	allowSubscribe bool
    68  
    69  	subLock    sync.Mutex
    70  	serverSubs map[ID]*Subscription
    71  }
    72  
    73  type callProc struct {
    74  	ctx       context.Context
    75  	notifiers []*Notifier
    76  }
    77  
    78  func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler {
    79  	rootCtx, cancelRoot := context.WithCancel(connCtx)
    80  	h := &handler{
    81  		reg:            reg,
    82  		idgen:          idgen,
    83  		conn:           conn,
    84  		respWait:       make(map[string]*requestOp),
    85  		clientSubs:     make(map[string]*ClientSubscription),
    86  		rootCtx:        rootCtx,
    87  		cancelRoot:     cancelRoot,
    88  		allowSubscribe: true,
    89  		serverSubs:     make(map[ID]*Subscription),
    90  	}
    91  	h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe))
    92  	return h
    93  }
    94  
    95  // handleBatch executes all messages in a batch and returns the responses.
    96  func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
    97  	// Emit error response for empty batches:
    98  	if len(msgs) == 0 {
    99  		rpcErrorResponsesCounter.Inc(1)
   100  		h.startCallProc(func(cp *callProc) {
   101  			h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"}))
   102  		})
   103  		return
   104  	}
   105  
   106  	rpcTotalRequestsCounter.Inc(int64(len(msgs)))
   107  
   108  	// Handle non-call messages first:
   109  	calls := make([]*jsonrpcMessage, 0, len(msgs))
   110  	for _, msg := range msgs {
   111  		if handled := h.handleImmediate(msg); !handled {
   112  			calls = append(calls, msg)
   113  		}
   114  	}
   115  	if len(calls) == 0 {
   116  		return
   117  	}
   118  
   119  	if atomic.LoadInt64(&pendingRequestCount) > pendingRequestLimit {
   120  		rpcErrorResponsesCounter.Inc(int64(len(calls)))
   121  		err := &invalidRequestError{"server requests exceed the limit"}
   122  		logger.Debug(fmt.Sprintf("request error %v\n", err))
   123  		h.startCallProc(func(cp *callProc) {
   124  			h.conn.writeJSON(cp.ctx, errorMessage(err))
   125  		})
   126  		return
   127  	}
   128  
   129  	// Process calls on a goroutine because they may block indefinitely:
   130  	h.startCallProc(func(cp *callProc) {
   131  		answers := make([]*jsonrpcMessage, 0, len(msgs))
   132  		for _, msg := range calls {
   133  			if answer := h.handleCallMsg(cp, msg); answer != nil {
   134  				answers = append(answers, answer)
   135  			}
   136  		}
   137  		h.addSubscriptions(cp.notifiers)
   138  		if len(answers) > 0 {
   139  			h.conn.writeJSON(cp.ctx, answers)
   140  		}
   141  		for _, n := range cp.notifiers {
   142  			n.activate()
   143  		}
   144  	})
   145  }
   146  
   147  // handleMsg handles a single message.
   148  func (h *handler) handleMsg(msg *jsonrpcMessage) {
   149  	rpcTotalRequestsCounter.Inc(1)
   150  	if ok := h.handleImmediate(msg); ok {
   151  		return
   152  	}
   153  
   154  	if atomic.LoadInt64(&pendingRequestCount) > pendingRequestLimit {
   155  		rpcErrorResponsesCounter.Inc(1)
   156  		err := &invalidRequestError{"server requests exceed the limit"}
   157  		logger.Debug(fmt.Sprintf("request error %v\n", err))
   158  		h.startCallProc(func(cp *callProc) {
   159  			h.conn.writeJSON(cp.ctx, errorMessage(err))
   160  		})
   161  		return
   162  	}
   163  
   164  	h.startCallProc(func(cp *callProc) {
   165  		answer := h.handleCallMsg(cp, msg)
   166  		h.addSubscriptions(cp.notifiers)
   167  		if answer != nil {
   168  			h.conn.writeJSON(cp.ctx, answer)
   169  		}
   170  		for _, n := range cp.notifiers {
   171  			n.activate()
   172  		}
   173  	})
   174  }
   175  
   176  // close cancels all requests except for inflightReq and waits for
   177  // call goroutines to shut down.
   178  func (h *handler) close(err error, inflightReq *requestOp) {
   179  	h.cancelAllRequests(err, inflightReq)
   180  	h.callWG.Wait()
   181  	h.cancelRoot()
   182  	h.cancelServerSubscriptions(err)
   183  }
   184  
   185  // addRequestOp registers a request operation.
   186  func (h *handler) addRequestOp(op *requestOp) {
   187  	for _, id := range op.ids {
   188  		h.respWait[string(id)] = op
   189  	}
   190  }
   191  
   192  // removeRequestOps stops waiting for the given request IDs.
   193  func (h *handler) removeRequestOp(op *requestOp) {
   194  	for _, id := range op.ids {
   195  		delete(h.respWait, string(id))
   196  	}
   197  }
   198  
   199  // cancelAllRequests unblocks and removes pending requests and active subscriptions.
   200  func (h *handler) cancelAllRequests(err error, inflightReq *requestOp) {
   201  	didClose := make(map[*requestOp]bool)
   202  	if inflightReq != nil {
   203  		didClose[inflightReq] = true
   204  	}
   205  
   206  	for id, op := range h.respWait {
   207  		// Remove the op so that later calls will not close op.resp again.
   208  		delete(h.respWait, id)
   209  
   210  		if !didClose[op] {
   211  			op.err = err
   212  			close(op.resp)
   213  			didClose[op] = true
   214  		}
   215  	}
   216  	for id, sub := range h.clientSubs {
   217  		delete(h.clientSubs, id)
   218  		sub.quitWithError(err, false)
   219  	}
   220  }
   221  
   222  func (h *handler) addSubscriptions(nn []*Notifier) {
   223  	h.subLock.Lock()
   224  	defer h.subLock.Unlock()
   225  
   226  	for _, n := range nn {
   227  		if sub := n.takeSubscription(); sub != nil {
   228  			h.serverSubs[sub.ID] = sub
   229  		}
   230  	}
   231  }
   232  
   233  // cancelServerSubscriptions removes all subscriptions and closes their error channels.
   234  func (h *handler) cancelServerSubscriptions(err error) {
   235  	h.subLock.Lock()
   236  	defer h.subLock.Unlock()
   237  
   238  	for id, s := range h.serverSubs {
   239  		s.err <- err
   240  		close(s.err)
   241  		delete(h.serverSubs, id)
   242  	}
   243  }
   244  
   245  // startCallProc runs fn in a new goroutine and starts tracking it in the h.calls wait group.
   246  func (h *handler) startCallProc(fn func(*callProc)) {
   247  	atomic.AddInt64(&pendingRequestCount, 1)
   248  	rpcPendingRequestsCount.Inc(1)
   249  	h.callWG.Add(1)
   250  	go func() {
   251  		ctx, cancel := context.WithCancel(h.rootCtx)
   252  		defer h.callWG.Done()
   253  		defer cancel()
   254  		defer atomic.AddInt64(&pendingRequestCount, -1)
   255  		fn(&callProc{ctx: ctx})
   256  	}()
   257  }
   258  
   259  // handleImmediate executes non-call messages. It returns false if the message is a
   260  // call or requires a reply.
   261  func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
   262  	start := time.Now()
   263  	switch {
   264  	case msg.isNotification():
   265  		if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
   266  			h.handleSubscriptionResult(msg)
   267  			return true
   268  		}
   269  		return false
   270  	case msg.isResponse():
   271  		h.handleResponse(msg)
   272  		logger.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start))
   273  		return true
   274  	default:
   275  		return false
   276  	}
   277  }
   278  
   279  // handleSubscriptionResult processes subscription notifications.
   280  func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
   281  	var result subscriptionResult
   282  	if err := json.Unmarshal(msg.Params, &result); err != nil {
   283  		logger.Debug("Dropping invalid subscription message")
   284  		return
   285  	}
   286  	logger.Trace("rpc client Notification", "msg", log.Lazy{Fn: func() string {
   287  		return fmt.Sprint("<-readResp: notification ", msg)
   288  	}})
   289  	if h.clientSubs[result.ID] != nil {
   290  		h.clientSubs[result.ID].deliver(result.Result)
   291  	}
   292  }
   293  
   294  // handleResponse processes method call responses.
   295  func (h *handler) handleResponse(msg *jsonrpcMessage) {
   296  	op := h.respWait[string(msg.ID)]
   297  	if op == nil {
   298  		logger.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
   299  		return
   300  	}
   301  	logger.Trace("rpc client Response", "msg", log.Lazy{Fn: func() string {
   302  		return fmt.Sprint("<-readResp: response ", msg)
   303  	}})
   304  	delete(h.respWait, string(msg.ID))
   305  	// For normal responses, just forward the reply to Call/BatchCall.
   306  	if op.sub == nil {
   307  		op.resp <- msg
   308  		return
   309  	}
   310  	// For subscription responses, start the subscription if the server
   311  	// indicates success. KlaySubscribe gets unblocked in either case through
   312  	// the op.resp channel.
   313  	defer close(op.resp)
   314  	if msg.Error != nil {
   315  		op.err = msg.Error
   316  		return
   317  	}
   318  	if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil {
   319  		go op.sub.start()
   320  		h.clientSubs[op.sub.subid] = op.sub
   321  	}
   322  }
   323  
   324  // handleCallMsg executes a call message and returns the answer.
   325  func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
   326  	start := time.Now()
   327  	switch {
   328  	case msg.isNotification():
   329  		h.handleCall(ctx, msg)
   330  		logger.Trace("Served "+msg.Method, "duration", time.Since(start))
   331  		return nil
   332  	case msg.isCall():
   333  		resp := h.handleCall(ctx, msg)
   334  		if resp.Error != nil {
   335  			logger.Debug("Served "+msg.Method, "reqid", idForLog{msg.ID}, "duration", time.Since(start), "err", resp.Error.Message)
   336  		} else {
   337  			logger.Trace("Served "+msg.Method, "reqid", idForLog{msg.ID}, "duration", time.Since(start))
   338  		}
   339  		return resp
   340  	case msg.hasValidID():
   341  		rpcErrorResponsesCounter.Inc(1)
   342  		return msg.errorResponse(&invalidRequestError{"invalid request"})
   343  	default:
   344  		rpcErrorResponsesCounter.Inc(1)
   345  		return errorMessage(&invalidRequestError{"invalid request"})
   346  	}
   347  }
   348  
   349  // handleCall processes method calls.
   350  func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
   351  	if msg.isSubscribe() {
   352  		return h.handleSubscribe(cp, msg)
   353  	}
   354  	var callb *callback
   355  	if msg.isUnsubscribe() {
   356  		callb = h.unsubscribeCb
   357  		wsUnsubscriptionReqCounter.Inc(1)
   358  	} else {
   359  		callb = h.reg.callback(msg.Method)
   360  	}
   361  	if callb == nil {
   362  		rpcErrorResponsesCounter.Inc(1)
   363  		return msg.errorResponse(&methodNotFoundError{method: msg.Method})
   364  	}
   365  	args, err := parsePositionalArguments(msg.Params, callb.argTypes)
   366  	if err != nil {
   367  		rpcErrorResponsesCounter.Inc(1)
   368  		return msg.errorResponse(&invalidParamsError{err.Error()})
   369  	}
   370  	return h.runMethod(cp.ctx, msg, callb, args)
   371  }
   372  
   373  // handleSubscribe processes *_subscribe method calls.
   374  func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
   375  	if !h.allowSubscribe {
   376  		rpcErrorResponsesCounter.Inc(1)
   377  		return msg.errorResponse(ErrNotificationsUnsupported)
   378  	}
   379  
   380  	if int32(len(h.serverSubs)) >= MaxSubscriptionPerWSConn {
   381  		rpcErrorResponsesCounter.Inc(1)
   382  		return msg.errorResponse(&callbackError{
   383  			fmt.Sprintf("Maximum %d subscriptions are allowed for a websocket connection. "+
   384  				"The limit can be updated with 'admin_setMaxSubscriptionPerWSConn' API", MaxSubscriptionPerWSConn),
   385  		})
   386  	}
   387  
   388  	// Subscription method name is first argument.
   389  	name, err := parseSubscriptionName(msg.Params)
   390  	if err != nil {
   391  		rpcErrorResponsesCounter.Inc(1)
   392  		return msg.errorResponse(&invalidParamsError{err.Error()})
   393  	}
   394  	namespace := msg.namespace()
   395  	callb := h.reg.subscription(namespace, name)
   396  	if callb == nil {
   397  		rpcErrorResponsesCounter.Inc(1)
   398  		return msg.errorResponse(&subscriptionNotFoundError{namespace, name})
   399  	}
   400  
   401  	// Parse subscription name arg too, but remove it before calling the callback.
   402  	argTypes := append([]reflect.Type{stringType}, callb.argTypes...)
   403  	args, err := parsePositionalArguments(msg.Params, argTypes)
   404  	if err != nil {
   405  		rpcErrorResponsesCounter.Inc(1)
   406  		return msg.errorResponse(&invalidParamsError{err.Error()})
   407  	}
   408  	args = args[1:]
   409  
   410  	// Install notifier in context so the subscription handler can find it.
   411  	n := &Notifier{h: h, namespace: namespace}
   412  	cp.notifiers = append(cp.notifiers, n)
   413  	ctx := context.WithValue(cp.ctx, notifierKey{}, n)
   414  
   415  	wsSubscriptionReqCounter.Inc(1)
   416  
   417  	return h.runMethod(ctx, msg, callb, args)
   418  }
   419  
   420  // runMethod runs the Go callback for an RPC method.
   421  func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *callback, args []reflect.Value) *jsonrpcMessage {
   422  	result, err := callb.call(ctx, msg.Method, args)
   423  	if err != nil {
   424  		rpcErrorResponsesCounter.Inc(1)
   425  		return msg.errorResponse(err)
   426  	}
   427  
   428  	rpcSuccessResponsesCounter.Inc(1)
   429  	return msg.response(result)
   430  }
   431  
   432  // unsubscribe is the callback function for all *_unsubscribe calls.
   433  func (h *handler) unsubscribe(ctx context.Context, id ID) (bool, error) {
   434  	h.subLock.Lock()
   435  	defer h.subLock.Unlock()
   436  
   437  	s := h.serverSubs[id]
   438  	if s == nil {
   439  		return false, ErrSubscriptionNotFound
   440  	}
   441  	close(s.err)
   442  	delete(h.serverSubs, id)
   443  	return true, nil
   444  }
   445  
   446  type idForLog struct{ json.RawMessage }
   447  
   448  func (id idForLog) String() string {
   449  	if s, err := strconv.Unquote(string(id.RawMessage)); err == nil {
   450  		return s
   451  	}
   452  	return string(id.RawMessage)
   453  }