github.com/klaytn/klaytn@v1.12.1/networks/rpc/handler.go (about)

     1  // Modifications Copyright 2022 The klaytn Authors
     2  // Copyright 2022 The go-ethereum Authors
     3  // This file is part of the go-ethereum library.
     4  //
     5  // The go-ethereum library is free software: you can redistribute it and/or modify
     6  // it under the terms of the GNU Lesser General Public License as published by
     7  // the Free Software Foundation, either version 3 of the License, or
     8  // (at your option) any later version.
     9  //
    10  // The go-ethereum library is distributed in the hope that it will be useful,
    11  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    12  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    13  // GNU Lesser General Public License for more details.
    14  //
    15  // You should have received a copy of the GNU Lesser General Public License
    16  // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>.
    17  //
    18  // This file is derived from rpc/handler.go (2022/08/04).
    19  // Modified and improved for the klaytn development.
    20  
    21  package rpc
    22  
import (
	"context"
	"encoding/json"
	"errors"
	"fmt"
	"reflect"
	"strconv"
	"strings"
	"sync"
	"sync/atomic"
	"time"

	"github.com/klaytn/klaytn/log"
	"github.com/klaytn/klaytn/storage/statedb"
)
    37  
// handler handles JSON-RPC messages. There is one handler per connection. Note that
// handler is not safe for concurrent use. Message handling never blocks indefinitely
// because RPCs are processed on background goroutines launched by handler.
//
// The entry points for incoming messages are:
//
//	h.handleMsg(message)
//	h.handleBatch(message)
//
// Outgoing calls use the requestOp struct. Register the request before sending it
// on the connection:
//
//	op := &requestOp{ids: ...}
//	h.addRequestOp(op)
//
// Now send the request, then wait for the reply to be delivered through handleMsg:
//
//	if err := op.wait(...); err != nil {
//	    h.removeRequestOp(op) // timeout, etc.
//	}
//
// The server-side subscription map (serverSubs) is accessed from the call
// goroutines started by startCallProc and must only be touched while holding
// subLock.
type handler struct {
	reg            *serviceRegistry               // resolves method and subscription callbacks
	unsubscribeCb  *callback                      // callback used for every *_unsubscribe call
	idgen          func() ID                      // subscription ID generator
	respWait       map[string]*requestOp          // active client requests
	clientSubs     map[string]*ClientSubscription // active client subscriptions
	callWG         sync.WaitGroup                 // pending call goroutines
	rootCtx        context.Context                // canceled by close()
	cancelRoot     func()                         // cancel function for rootCtx
	conn           jsonWriter                     // where responses will be sent
	allowSubscribe bool                           // if false, *_subscribe calls are rejected

	subLock    sync.Mutex           // guards serverSubs
	serverSubs map[ID]*Subscription // active server subscriptions, keyed by subscription ID
}
    73  
// callProc is the per-invocation state passed to the functions run by
// startCallProc.
type callProc struct {
	ctx       context.Context // derived from handler.rootCtx; canceled when the call goroutine returns
	notifiers []*Notifier     // notifiers created by handleSubscribe during this call; registered and activated by handleMsg/handleBatch
}
    78  
    79  func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler {
    80  	rootCtx, cancelRoot := context.WithCancel(connCtx)
    81  	h := &handler{
    82  		reg:            reg,
    83  		idgen:          idgen,
    84  		conn:           conn,
    85  		respWait:       make(map[string]*requestOp),
    86  		clientSubs:     make(map[string]*ClientSubscription),
    87  		rootCtx:        rootCtx,
    88  		cancelRoot:     cancelRoot,
    89  		allowSubscribe: true,
    90  		serverSubs:     make(map[ID]*Subscription),
    91  	}
    92  	h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe))
    93  	return h
    94  }
    95  
    96  // handleBatch executes all messages in a batch and returns the responses.
    97  func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
    98  	// Emit error response for empty batches:
    99  	if len(msgs) == 0 {
   100  		rpcErrorResponsesCounter.Inc(1)
   101  		h.startCallProc(func(cp *callProc) {
   102  			h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"}))
   103  		})
   104  		return
   105  	}
   106  
   107  	rpcTotalRequestsCounter.Inc(int64(len(msgs)))
   108  
   109  	// Handle non-call messages first:
   110  	calls := make([]*jsonrpcMessage, 0, len(msgs))
   111  	for _, msg := range msgs {
   112  		if handled := h.handleImmediate(msg); !handled {
   113  			calls = append(calls, msg)
   114  		}
   115  	}
   116  	if len(calls) == 0 {
   117  		return
   118  	}
   119  
   120  	if atomic.LoadInt64(&pendingRequestCount) > pendingRequestLimit {
   121  		rpcErrorResponsesCounter.Inc(int64(len(calls)))
   122  		err := &invalidRequestError{"server requests exceed the limit"}
   123  		logger.Debug(fmt.Sprintf("request error %v\n", err))
   124  		h.startCallProc(func(cp *callProc) {
   125  			h.conn.writeJSON(cp.ctx, errorMessage(err))
   126  		})
   127  		return
   128  	}
   129  
   130  	// Process calls on a goroutine because they may block indefinitely:
   131  	h.startCallProc(func(cp *callProc) {
   132  		answers := make([]*jsonrpcMessage, 0, len(msgs))
   133  		for _, msg := range calls {
   134  			if answer := h.handleCallMsg(cp, msg); answer != nil {
   135  				answers = append(answers, answer)
   136  			}
   137  		}
   138  		h.addSubscriptions(cp.notifiers)
   139  		if len(answers) > 0 {
   140  			h.conn.writeJSON(cp.ctx, answers)
   141  		}
   142  		for _, n := range cp.notifiers {
   143  			n.activate()
   144  		}
   145  	})
   146  }
   147  
   148  // handleMsg handles a single message.
   149  func (h *handler) handleMsg(msg *jsonrpcMessage) {
   150  	rpcTotalRequestsCounter.Inc(1)
   151  	if ok := h.handleImmediate(msg); ok {
   152  		return
   153  	}
   154  
   155  	if atomic.LoadInt64(&pendingRequestCount) > pendingRequestLimit {
   156  		rpcErrorResponsesCounter.Inc(1)
   157  		err := &invalidRequestError{"server requests exceed the limit"}
   158  		logger.Debug(fmt.Sprintf("request error %v\n", err))
   159  		h.startCallProc(func(cp *callProc) {
   160  			h.conn.writeJSON(cp.ctx, errorMessage(err))
   161  		})
   162  		return
   163  	}
   164  
   165  	h.startCallProc(func(cp *callProc) {
   166  		answer := h.handleCallMsg(cp, msg)
   167  		h.addSubscriptions(cp.notifiers)
   168  		if answer != nil {
   169  			h.conn.writeJSON(cp.ctx, answer)
   170  		}
   171  		for _, n := range cp.notifiers {
   172  			n.activate()
   173  		}
   174  	})
   175  }
   176  
// close cancels all requests except for inflightReq and waits for
// call goroutines to shut down.
//
// The order of the steps is significant:
//  1. pending client requests and client subscriptions are failed with err,
//  2. close blocks until every goroutine started by startCallProc has returned,
//  3. rootCtx is canceled,
//  4. server-side subscriptions are torn down (each receives err on its error
//     channel, which is then closed).
//
// NOTE(review): rootCtx is canceled only after callWG.Wait, so running calls
// are not interrupted through rootCtx by this method. Confirm that connCtx is
// canceled by the caller or that calls terminate on their own.
func (h *handler) close(err error, inflightReq *requestOp) {
	h.cancelAllRequests(err, inflightReq)
	h.callWG.Wait()
	h.cancelRoot()
	h.cancelServerSubscriptions(err)
}
   185  
   186  // addRequestOp registers a request operation.
   187  func (h *handler) addRequestOp(op *requestOp) {
   188  	for _, id := range op.ids {
   189  		h.respWait[string(id)] = op
   190  	}
   191  }
   192  
   193  // removeRequestOps stops waiting for the given request IDs.
   194  func (h *handler) removeRequestOp(op *requestOp) {
   195  	for _, id := range op.ids {
   196  		delete(h.respWait, string(id))
   197  	}
   198  }
   199  
   200  // cancelAllRequests unblocks and removes pending requests and active subscriptions.
   201  func (h *handler) cancelAllRequests(err error, inflightReq *requestOp) {
   202  	didClose := make(map[*requestOp]bool)
   203  	if inflightReq != nil {
   204  		didClose[inflightReq] = true
   205  	}
   206  
   207  	for id, op := range h.respWait {
   208  		// Remove the op so that later calls will not close op.resp again.
   209  		delete(h.respWait, id)
   210  
   211  		if !didClose[op] {
   212  			op.err = err
   213  			close(op.resp)
   214  			didClose[op] = true
   215  		}
   216  	}
   217  	for id, sub := range h.clientSubs {
   218  		delete(h.clientSubs, id)
   219  		sub.quitWithError(err, false)
   220  	}
   221  }
   222  
   223  func (h *handler) addSubscriptions(nn []*Notifier) {
   224  	h.subLock.Lock()
   225  	defer h.subLock.Unlock()
   226  
   227  	for _, n := range nn {
   228  		if sub := n.takeSubscription(); sub != nil {
   229  			h.serverSubs[sub.ID] = sub
   230  		}
   231  	}
   232  }
   233  
   234  // cancelServerSubscriptions removes all subscriptions and closes their error channels.
   235  func (h *handler) cancelServerSubscriptions(err error) {
   236  	h.subLock.Lock()
   237  	defer h.subLock.Unlock()
   238  
   239  	for id, s := range h.serverSubs {
   240  		s.err <- err
   241  		close(s.err)
   242  		delete(h.serverSubs, id)
   243  	}
   244  }
   245  
   246  // startCallProc runs fn in a new goroutine and starts tracking it in the h.calls wait group.
   247  func (h *handler) startCallProc(fn func(*callProc)) {
   248  	atomic.AddInt64(&pendingRequestCount, 1)
   249  	rpcPendingRequestsCount.Inc(1)
   250  	h.callWG.Add(1)
   251  	go func() {
   252  		ctx, cancel := context.WithCancel(h.rootCtx)
   253  		defer h.callWG.Done()
   254  		defer cancel()
   255  		defer atomic.AddInt64(&pendingRequestCount, -1)
   256  		fn(&callProc{ctx: ctx})
   257  	}()
   258  }
   259  
   260  // handleImmediate executes non-call messages. It returns false if the message is a
   261  // call or requires a reply.
   262  func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
   263  	start := time.Now()
   264  	switch {
   265  	case msg.isNotification():
   266  		if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
   267  			h.handleSubscriptionResult(msg)
   268  			return true
   269  		}
   270  		return false
   271  	case msg.isResponse():
   272  		h.handleResponse(msg)
   273  		logger.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start))
   274  		return true
   275  	default:
   276  		return false
   277  	}
   278  }
   279  
   280  // handleSubscriptionResult processes subscription notifications.
   281  func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
   282  	var result subscriptionResult
   283  	if err := json.Unmarshal(msg.Params, &result); err != nil {
   284  		logger.Debug("Dropping invalid subscription message")
   285  		return
   286  	}
   287  	logger.Trace("rpc client Notification", "msg", log.Lazy{Fn: func() string {
   288  		return fmt.Sprint("<-readResp: notification ", msg)
   289  	}})
   290  	if h.clientSubs[result.ID] != nil {
   291  		h.clientSubs[result.ID].deliver(result.Result)
   292  	}
   293  }
   294  
   295  // handleResponse processes method call responses.
   296  func (h *handler) handleResponse(msg *jsonrpcMessage) {
   297  	op := h.respWait[string(msg.ID)]
   298  	if op == nil {
   299  		logger.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
   300  		return
   301  	}
   302  	logger.Trace("rpc client Response", "msg", log.Lazy{Fn: func() string {
   303  		return fmt.Sprint("<-readResp: response ", msg)
   304  	}})
   305  	delete(h.respWait, string(msg.ID))
   306  	// For normal responses, just forward the reply to Call/BatchCall.
   307  	if op.sub == nil {
   308  		op.resp <- msg
   309  		return
   310  	}
   311  	// For subscription responses, start the subscription if the server
   312  	// indicates success. KlaySubscribe gets unblocked in either case through
   313  	// the op.resp channel.
   314  	defer close(op.resp)
   315  	if msg.Error != nil {
   316  		op.err = msg.Error
   317  		return
   318  	}
   319  	if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil {
   320  		go op.sub.start()
   321  		h.clientSubs[op.sub.subid] = op.sub
   322  	}
   323  }
   324  
   325  // handleCallMsg executes a call message and returns the answer.
   326  func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
   327  	start := time.Now()
   328  	switch {
   329  	case msg.isNotification():
   330  		h.handleCall(ctx, msg)
   331  		logger.Trace("Served "+msg.Method, "duration", time.Since(start))
   332  		return nil
   333  	case msg.isCall():
   334  		resp := h.handleCall(ctx, msg)
   335  		var ctx []interface{}
   336  		ctx = append(ctx, "reqid", idForLog{msg.ID}, "duration", time.Since(start))
   337  		if resp.Error != nil {
   338  			ctx = append(ctx, "err", resp.Error.Message)
   339  			if resp.Error.Data != nil {
   340  				ctx = append(ctx, "errdata", resp.Error.Data)
   341  			}
   342  			logger.Debug("Served "+msg.Method, ctx...)
   343  		} else {
   344  			logger.Trace("Served "+msg.Method, ctx...)
   345  		}
   346  		return resp
   347  	case msg.hasValidID():
   348  		rpcErrorResponsesCounter.Inc(1)
   349  		return msg.errorResponse(&invalidRequestError{"invalid request"})
   350  	default:
   351  		rpcErrorResponsesCounter.Inc(1)
   352  		return errorMessage(&invalidRequestError{"invalid request"})
   353  	}
   354  }
   355  
   356  // handleCall processes method calls.
   357  func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
   358  	if msg.isSubscribe() {
   359  		return h.handleSubscribe(cp, msg)
   360  	}
   361  	var callb *callback
   362  	if msg.isUnsubscribe() {
   363  		callb = h.unsubscribeCb
   364  		wsUnsubscriptionReqCounter.Inc(1)
   365  	} else {
   366  		callb = h.reg.callback(msg.Method)
   367  	}
   368  	if callb == nil {
   369  		rpcErrorResponsesCounter.Inc(1)
   370  		return msg.errorResponse(&methodNotFoundError{method: msg.Method})
   371  	}
   372  	args, err := parsePositionalArguments(msg.Params, callb.argTypes)
   373  	if err != nil {
   374  		rpcErrorResponsesCounter.Inc(1)
   375  		return msg.errorResponse(&invalidParamsError{err.Error()})
   376  	}
   377  	return h.runMethod(cp.ctx, msg, callb, args)
   378  }
   379  
   380  // handleSubscribe processes *_subscribe method calls.
   381  func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
   382  	if !h.allowSubscribe {
   383  		rpcErrorResponsesCounter.Inc(1)
   384  		return msg.errorResponse(ErrNotificationsUnsupported)
   385  	}
   386  
   387  	if int32(len(h.serverSubs)) >= MaxSubscriptionPerWSConn {
   388  		rpcErrorResponsesCounter.Inc(1)
   389  		return msg.errorResponse(&callbackError{
   390  			fmt.Sprintf("Maximum %d subscriptions are allowed for a websocket connection. "+
   391  				"The limit can be updated with 'admin_setMaxSubscriptionPerWSConn' API", MaxSubscriptionPerWSConn),
   392  		})
   393  	}
   394  
   395  	// Subscription method name is first argument.
   396  	name, err := parseSubscriptionName(msg.Params)
   397  	if err != nil {
   398  		rpcErrorResponsesCounter.Inc(1)
   399  		return msg.errorResponse(&invalidParamsError{err.Error()})
   400  	}
   401  	namespace := msg.namespace()
   402  	callb := h.reg.subscription(namespace, name)
   403  	if callb == nil {
   404  		rpcErrorResponsesCounter.Inc(1)
   405  		return msg.errorResponse(&subscriptionNotFoundError{namespace, name})
   406  	}
   407  
   408  	// Parse subscription name arg too, but remove it before calling the callback.
   409  	argTypes := append([]reflect.Type{stringType}, callb.argTypes...)
   410  	args, err := parsePositionalArguments(msg.Params, argTypes)
   411  	if err != nil {
   412  		rpcErrorResponsesCounter.Inc(1)
   413  		return msg.errorResponse(&invalidParamsError{err.Error()})
   414  	}
   415  	args = args[1:]
   416  
   417  	// Install notifier in context so the subscription handler can find it.
   418  	n := &Notifier{h: h, namespace: namespace}
   419  	cp.notifiers = append(cp.notifiers, n)
   420  	ctx := context.WithValue(cp.ctx, notifierKey{}, n)
   421  
   422  	wsSubscriptionReqCounter.Inc(1)
   423  
   424  	return h.runMethod(ctx, msg, callb, args)
   425  }
   426  
   427  // runMethod runs the Go callback for an RPC method.
   428  func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *callback, args []reflect.Value) *jsonrpcMessage {
   429  	result, err := callb.call(ctx, msg.Method, args)
   430  	if err != nil {
   431  		// TODO-Klaytn:
   432  		// 1. The single URL may be extended to the list of URL.
   433  		// 2. The URL (list or single) seems not modifiable at runtime.
   434  		// 3. Any request can be relayed as well as the state lack.
   435  		// 4. Make a new rpcErrorResponse struct.
   436  		if UpstreamArchiveEN != "" && shouldRequestUpstream(err) {
   437  			return requestUpstream(ctx, msg, args)
   438  		}
   439  		rpcErrorResponsesCounter.Inc(1)
   440  		return msg.errorResponse(err)
   441  	}
   442  
   443  	rpcSuccessResponsesCounter.Inc(1)
   444  	return msg.response(result)
   445  }
   446  
   447  // shouldRequestUpstream is a function that determines whether must be requested upstream.
   448  func shouldRequestUpstream(err error) bool {
   449  	switch err.(type) {
   450  	case *statedb.MissingNodeError:
   451  		return true
   452  	default:
   453  		return false
   454  	}
   455  }
   456  
   457  // requestUpstream is the function to request upstream archive en
   458  func requestUpstream(ctx context.Context, msg *jsonrpcMessage, args []reflect.Value) *jsonrpcMessage {
   459  	ctx, cancel := context.WithTimeout(ctx, DefaultHTTPTimeouts.ExecutionTimeout)
   460  	defer cancel()
   461  
   462  	var result interface{}
   463  	c, err := DialContext(ctx, UpstreamArchiveEN)
   464  	if err == nil {
   465  		defer c.Close()
   466  
   467  		var params []interface{}
   468  		for _, a := range args {
   469  			params = append(params, a.Interface())
   470  		}
   471  		if err = c.CallContext(ctx, &result, msg.Method, params...); err == nil {
   472  			rpcSuccessResponsesCounter.Inc(1)
   473  			return msg.response(result)
   474  		}
   475  	}
   476  
   477  	rpcErrorResponsesCounter.Inc(1)
   478  	return msg.errorResponse(err)
   479  }
   480  
   481  // unsubscribe is the callback function for all *_unsubscribe calls.
   482  func (h *handler) unsubscribe(ctx context.Context, id ID) (bool, error) {
   483  	h.subLock.Lock()
   484  	defer h.subLock.Unlock()
   485  
   486  	s := h.serverSubs[id]
   487  	if s == nil {
   488  		return false, ErrSubscriptionNotFound
   489  	}
   490  	close(s.err)
   491  	delete(h.serverSubs, id)
   492  	return true, nil
   493  }
   494  
   495  type idForLog struct{ json.RawMessage }
   496  
   497  func (id idForLog) String() string {
   498  	if s, err := strconv.Unquote(string(id.RawMessage)); err == nil {
   499  		return s
   500  	}
   501  	return string(id.RawMessage)
   502  }