github.com/amazechain/amc@v0.1.3/modules/rpc/jsonrpc/handler.go (about)

     1  // Copyright 2022 The AmazeChain Authors
     2  // This file is part of the AmazeChain library.
     3  //
     4  // The AmazeChain library is free software: you can redistribute it and/or modify
     5  // it under the terms of the GNU Lesser General Public License as published by
     6  // the Free Software Foundation, either version 3 of the License, or
     7  // (at your option) any later version.
     8  //
     9  // The AmazeChain library is distributed in the hope that it will be useful,
    10  // but WITHOUT ANY WARRANTY; without even the implied warranty of
    11  // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
    12  // GNU Lesser General Public License for more details.
    13  //
    14  // You should have received a copy of the GNU Lesser General Public License
    15  // along with the AmazeChain library. If not, see <http://www.gnu.org/licenses/>.
    16  
    17  package jsonrpc
    18  
    19  import (
    20  	"context"
    21  	"encoding/json"
    22  	"reflect"
    23  	"strconv"
    24  	"strings"
    25  	"sync"
    26  	"time"
    27  
    28  	"github.com/amazechain/amc/log"
    29  )
    30  
// handler dispatches JSON-RPC messages for the connection it was created with
// (see newHandler). Call messages run on call goroutines started by
// startCallProc, responses are routed to the waiting requestOp in respWait,
// and server-side subscriptions created by *_subscribe calls are tracked in
// serverSubs.
type handler struct {
	reg            *serviceRegistry      // source of method and subscription callbacks
	unsubscribeCb  *callback             // wraps h.unsubscribe; serves every *_unsubscribe call
	idgen          func() ID             // subscription ID generator
	respWait       map[string]*requestOp // active client requests
	callWG         sync.WaitGroup        // pending call goroutines
	rootCtx        context.Context       // canceled by close()
	cancelRoot     func()                // cancel function for rootCtx
	conn           jsonWriter            // where responses will be sent
	allowSubscribe bool                  // when false, *_subscribe calls are rejected

	// subLock guards serverSubs (every access in this file holds it).
	// NOTE(review): respWait is not guarded by subLock; presumably it is only
	// touched from a single goroutine — confirm.
	subLock    sync.Mutex
	serverSubs map[ID]*Subscription          // active server subscriptions, keyed by subscription ID
	clientSubs map[string]*ClientSubscription // active client subscriptions

	log log.Logger // root logger, tagged with "conn" when the remote address is known
}
    48  
// callProc holds the state of one call goroutine started by startCallProc.
type callProc struct {
	ctx       context.Context // derived from handler.rootCtx; canceled when the goroutine's fn returns
	notifiers []*Notifier     // notifiers created by *_subscribe calls handled on this goroutine
}
    53  
    54  func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler {
    55  	rootCtx, cancelRoot := context.WithCancel(connCtx)
    56  	h := &handler{
    57  		reg:            reg,
    58  		idgen:          idgen,
    59  		conn:           conn,
    60  		respWait:       make(map[string]*requestOp),
    61  		rootCtx:        rootCtx,
    62  		cancelRoot:     cancelRoot,
    63  		allowSubscribe: true,
    64  		serverSubs:     make(map[ID]*Subscription),
    65  		clientSubs:     make(map[string]*ClientSubscription),
    66  		log:            log.Root(),
    67  	}
    68  	if conn.remoteAddr() != "" {
    69  		h.log = h.log.New("conn", conn.remoteAddr())
    70  	}
    71  	h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe))
    72  	return h
    73  }
    74  
    75  // handleBatch executes all messages in a batch and returns the responses.
    76  func (h *handler) handleBatch(msgs []*jsonrpcMessage) {
    77  	// Emit error response for empty batches:
    78  	if len(msgs) == 0 {
    79  		h.startCallProc(func(cp *callProc) {
    80  			h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"}))
    81  		})
    82  		return
    83  	}
    84  
    85  	// Handle non-call messages first:
    86  	calls := make([]*jsonrpcMessage, 0, len(msgs))
    87  	for _, msg := range msgs {
    88  		if handled := h.handleImmediate(msg); !handled {
    89  			calls = append(calls, msg)
    90  		}
    91  	}
    92  	if len(calls) == 0 {
    93  		return
    94  	}
    95  	// Process calls on a goroutine because they may block indefinitely:
    96  	h.startCallProc(func(cp *callProc) {
    97  		answers := make([]*jsonrpcMessage, 0, len(msgs))
    98  		for _, msg := range calls {
    99  			if answer := h.handleCallMsg(cp, msg); answer != nil {
   100  				answers = append(answers, answer)
   101  			}
   102  		}
   103  		h.addSubscriptions(cp.notifiers)
   104  		if len(answers) > 0 {
   105  			h.conn.writeJSON(cp.ctx, answers)
   106  		}
   107  		for _, n := range cp.notifiers {
   108  			n.activate()
   109  		}
   110  	})
   111  }
   112  
   113  func (h *handler) handleMsg(msg *jsonrpcMessage) {
   114  	if ok := h.handleImmediate(msg); ok {
   115  		return
   116  	}
   117  	h.startCallProc(func(cp *callProc) {
   118  		answer := h.handleCallMsg(cp, msg)
   119  		h.addSubscriptions(cp.notifiers)
   120  		if answer != nil {
   121  			h.conn.writeJSON(cp.ctx, answer)
   122  		}
   123  		for _, n := range cp.notifiers {
   124  			n.activate()
   125  		}
   126  	})
   127  }
   128  
// close shuts the handler down, in this order:
//  1. fails every outstanding request in respWait with err (inflightReq is
//     skipped, so its response channel is not closed here);
//  2. waits for all call goroutines to return;
//  3. cancels rootCtx;
//  4. terminates all server subscriptions with err.
//
// NOTE(review): rootCtx is only canceled after callWG.Wait returns, so a call
// that blocks until its context is done keeps close waiting unless connCtx is
// canceled by the caller — confirm this ordering is intended.
func (h *handler) close(err error, inflightReq *requestOp) {
	h.cancelAllRequests(err, inflightReq)
	h.callWG.Wait()
	h.cancelRoot()
	h.cancelServerSubscriptions(err)
}
   135  
   136  func (h *handler) addRequestOp(op *requestOp) {
   137  	for _, id := range op.ids {
   138  		h.respWait[string(id)] = op
   139  	}
   140  }
   141  
   142  func (h *handler) removeRequestOp(op *requestOp) {
   143  	for _, id := range op.ids {
   144  		delete(h.respWait, string(id))
   145  	}
   146  }
   147  
   148  func (h *handler) cancelAllRequests(err error, inflightReq *requestOp) {
   149  	didClose := make(map[*requestOp]bool)
   150  	if inflightReq != nil {
   151  		didClose[inflightReq] = true
   152  	}
   153  
   154  	for id, op := range h.respWait {
   155  		// Remove the op so that later calls will not close op.resp again.
   156  		delete(h.respWait, id)
   157  
   158  		if !didClose[op] {
   159  			op.err = err
   160  			close(op.resp)
   161  			didClose[op] = true
   162  		}
   163  	}
   164  }
   165  
   166  func (h *handler) addSubscriptions(nn []*Notifier) {
   167  	h.subLock.Lock()
   168  	defer h.subLock.Unlock()
   169  
   170  	for _, n := range nn {
   171  		if sub := n.takeSubscription(); sub != nil {
   172  			h.serverSubs[sub.ID] = sub
   173  		}
   174  	}
   175  }
   176  
   177  // cancelServerSubscriptions removes all subscriptions and closes their error channels.
   178  func (h *handler) cancelServerSubscriptions(err error) {
   179  	h.subLock.Lock()
   180  	defer h.subLock.Unlock()
   181  
   182  	for id, s := range h.serverSubs {
   183  		s.err <- err
   184  		close(s.err)
   185  		delete(h.serverSubs, id)
   186  	}
   187  }
   188  
   189  func (h *handler) startCallProc(fn func(*callProc)) {
   190  	h.callWG.Add(1)
   191  	go func() {
   192  		ctx, cancel := context.WithCancel(h.rootCtx)
   193  		defer h.callWG.Done()
   194  		defer cancel()
   195  		fn(&callProc{ctx: ctx})
   196  	}()
   197  }
   198  
   199  func (h *handler) handleImmediate(msg *jsonrpcMessage) bool {
   200  	start := time.Now()
   201  	switch {
   202  	case msg.isNotification():
   203  		if strings.HasSuffix(msg.Method, notificationMethodSuffix) {
   204  			h.handleSubscriptionResult(msg)
   205  			return true
   206  		}
   207  		return false
   208  	case msg.isResponse():
   209  		h.handleResponse(msg)
   210  		h.log.Debug("Handled RPC response", "reqid", idForLog{msg.ID}, "t", time.Since(start))
   211  		return true
   212  	default:
   213  		return false
   214  	}
   215  }
   216  
   217  func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) {
   218  	var result subscriptionResult
   219  	if err := json.Unmarshal(msg.Params, &result); err != nil {
   220  		h.log.Debug("Dropping invalid subscription message")
   221  		return
   222  	}
   223  }
   224  
   225  func (h *handler) handleResponse(msg *jsonrpcMessage) {
   226  	op := h.respWait[string(msg.ID)]
   227  	if op == nil {
   228  		h.log.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID})
   229  		return
   230  	}
   231  	delete(h.respWait, string(msg.ID))
   232  	op.resp <- msg
   233  }
   234  
   235  func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
   236  	start := time.Now()
   237  	switch {
   238  	//case msg.isNotification():
   239  	case msg.isCall():
   240  		h.log.Trace("begin "+msg.Method, "p", string(msg.Params))
   241  		resp := h.handleCall(ctx, msg)
   242  		var ctx []interface{}
   243  		ctx = append(ctx, "reqid", idForLog{msg.ID}, "t", time.Since(start), "p", string(msg.Params), "r", string(resp.Result))
   244  		if resp.Error != nil {
   245  			ctx = append(ctx, "err", resp.Error.Message)
   246  			if resp.Error.Data != nil {
   247  				ctx = append(ctx, "errdata", resp.Error.Data)
   248  			}
   249  			h.log.Warn("Served "+msg.Method, ctx...)
   250  		} else {
   251  			h.log.Trace("Served "+msg.Method, ctx...)
   252  		}
   253  		return resp
   254  	case msg.hasValidID():
   255  		return msg.errorResponse(&invalidRequestError{"invalid request"})
   256  	default:
   257  		return errorMessage(&invalidRequestError{"invalid request"})
   258  	}
   259  }
   260  
   261  func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
   262  	if msg.isSubscribe() {
   263  		return h.handleSubscribe(cp, msg)
   264  	}
   265  	var callb *callback
   266  	if msg.isUnsubscribe() {
   267  		callb = h.unsubscribeCb
   268  	} else {
   269  		callb = h.reg.callback(msg.Method)
   270  	}
   271  	if callb == nil {
   272  		return msg.errorResponse(&methodNotFoundError{method: msg.Method})
   273  	}
   274  	args, err := parsePositionalArguments(msg.Params, callb.argTypes)
   275  	if err != nil {
   276  		return msg.errorResponse(&invalidParamsError{err.Error()})
   277  	}
   278  	start := time.Now()
   279  	answer := h.runMethod(cp.ctx, msg, callb, args)
   280  
   281  	// Collect the statistics for RPC calls if metrics is enabled.
   282  	// We only care about pure rpc call. Filter out subscription.
   283  	if callb != h.unsubscribeCb {
   284  		rpcRequestGauge.Inc()
   285  		if answer != nil && answer.Error != nil {
   286  			failedReqeustGauge.Inc()
   287  		}
   288  		newRPCServingTimerMS(msg.Method, answer == nil || answer.Error == nil).UpdateDuration(start)
   289  	}
   290  	return answer
   291  }
   292  
   293  // handleSubscribe processes *_subscribe method calls.
   294  func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage {
   295  	if !h.allowSubscribe {
   296  		return msg.errorResponse(ErrNotificationsUnsupported)
   297  	}
   298  
   299  	// Subscription method name is first argument.
   300  	name, err := parseSubscriptionName(msg.Params)
   301  	if err != nil {
   302  		return msg.errorResponse(&invalidParamsError{err.Error()})
   303  	}
   304  	namespace := msg.namespace()
   305  	callb := h.reg.subscription(namespace, name)
   306  	if callb == nil {
   307  		return msg.errorResponse(&subscriptionNotFoundError{namespace, name})
   308  	}
   309  
   310  	// Parse subscription name arg too, but remove it before calling the callback.
   311  	argTypes := append([]reflect.Type{stringType}, callb.argTypes...)
   312  	args, err := parsePositionalArguments(msg.Params, argTypes)
   313  	if err != nil {
   314  		return msg.errorResponse(&invalidParamsError{err.Error()})
   315  	}
   316  	args = args[1:]
   317  
   318  	// Install notifier in context so the subscription handler can find it.
   319  	n := &Notifier{h: h, namespace: namespace}
   320  	cp.notifiers = append(cp.notifiers, n)
   321  	ctx := context.WithValue(cp.ctx, notifierKey{}, n)
   322  
   323  	return h.runMethod(ctx, msg, callb, args)
   324  }
   325  
   326  func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *callback, args []reflect.Value) *jsonrpcMessage {
   327  	result, err := callb.call(ctx, msg.Method, args)
   328  	if err != nil {
   329  		return msg.errorResponse(err)
   330  	}
   331  	return msg.response(result)
   332  }
   333  
   334  // unsubscribe is the callback function for all *_unsubscribe calls.
   335  func (h *handler) unsubscribe(ctx context.Context, id ID) (bool, error) {
   336  	h.subLock.Lock()
   337  	defer h.subLock.Unlock()
   338  
   339  	s := h.serverSubs[id]
   340  	if s == nil {
   341  		return false, ErrSubscriptionNotFound
   342  	}
   343  	close(s.err)
   344  	delete(h.serverSubs, id)
   345  	return true, nil
   346  }
   347  
   348  type idForLog struct{ json.RawMessage }
   349  
   350  func (id idForLog) String() string {
   351  	if s, err := strconv.Unquote(string(id.RawMessage)); err == nil {
   352  		return s
   353  	}
   354  	return string(id.RawMessage)
   355  }