github.com/klaytn/klaytn@v1.10.2/networks/rpc/handler.go (about) 1 // Modifications Copyright 2022 The klaytn Authors 2 // Copyright 2022 The go-ethereum Authors 3 // This file is part of the go-ethereum library. 4 // 5 // The go-ethereum library is free software: you can redistribute it and/or modify 6 // it under the terms of the GNU Lesser General Public License as published by 7 // the Free Software Foundation, either version 3 of the License, or 8 // (at your option) any later version. 9 // 10 // The go-ethereum library is distributed in the hope that it will be useful, 11 // but WITHOUT ANY WARRANTY; without even the implied warranty of 12 // MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the 13 // GNU Lesser General Public License for more details. 14 // 15 // You should have received a copy of the GNU Lesser General Public License 16 // along with the go-ethereum library. If not, see <http://www.gnu.org/licenses/>. 17 // 18 // This file is derived from rpc/handler.go (2022/08/04). 19 // Modified and improved for the klaytn development. 20 21 package rpc 22 23 import ( 24 "context" 25 "encoding/json" 26 "fmt" 27 "reflect" 28 "strconv" 29 "strings" 30 "sync" 31 "sync/atomic" 32 "time" 33 34 "github.com/klaytn/klaytn/log" 35 ) 36 37 // handler handles JSON-RPC messages. There is one handler per connection. Note that 38 // handler is not safe for concurrent use. Message handling never blocks indefinitely 39 // because RPCs are processed on background goroutines launched by handler. 40 // 41 // The entry points for incoming messages are: 42 // 43 // h.handleMsg(message) 44 // h.handleBatch(message) 45 // 46 // Outgoing calls use the requestOp struct. Register the request before sending it 47 // on the connection: 48 // 49 // op := &requestOp{ids: ...} 50 // h.addRequestOp(op) 51 // 52 // Now send the request, then wait for the reply to be delivered through handleMsg: 53 // 54 // if err := op.wait(...); err != nil { 55 // h.removeRequestOp(op) // timeout, etc. 56 // } 57 type handler struct { 58 reg *serviceRegistry 59 unsubscribeCb *callback 60 idgen func() ID // subscription ID generator 61 respWait map[string]*requestOp // active client requests 62 clientSubs map[string]*ClientSubscription // active client subscriptions 63 callWG sync.WaitGroup // pending call goroutines 64 rootCtx context.Context // canceled by close() 65 cancelRoot func() // cancel function for rootCtx 66 conn jsonWriter // where responses will be sent 67 allowSubscribe bool 68 69 subLock sync.Mutex 70 serverSubs map[ID]*Subscription 71 } 72 73 type callProc struct { 74 ctx context.Context 75 notifiers []*Notifier 76 } 77 78 func newHandler(connCtx context.Context, conn jsonWriter, idgen func() ID, reg *serviceRegistry) *handler { 79 rootCtx, cancelRoot := context.WithCancel(connCtx) 80 h := &handler{ 81 reg: reg, 82 idgen: idgen, 83 conn: conn, 84 respWait: make(map[string]*requestOp), 85 clientSubs: make(map[string]*ClientSubscription), 86 rootCtx: rootCtx, 87 cancelRoot: cancelRoot, 88 allowSubscribe: true, 89 serverSubs: make(map[ID]*Subscription), 90 } 91 h.unsubscribeCb = newCallback(reflect.Value{}, reflect.ValueOf(h.unsubscribe)) 92 return h 93 } 94 95 // handleBatch executes all messages in a batch and returns the responses. 96 func (h *handler) handleBatch(msgs []*jsonrpcMessage) { 97 // Emit error response for empty batches: 98 if len(msgs) == 0 { 99 rpcErrorResponsesCounter.Inc(1) 100 h.startCallProc(func(cp *callProc) { 101 h.conn.writeJSON(cp.ctx, errorMessage(&invalidRequestError{"empty batch"})) 102 }) 103 return 104 } 105 106 rpcTotalRequestsCounter.Inc(int64(len(msgs))) 107 108 // Handle non-call messages first: 109 calls := make([]*jsonrpcMessage, 0, len(msgs)) 110 for _, msg := range msgs { 111 if handled := h.handleImmediate(msg); !handled { 112 calls = append(calls, msg) 113 } 114 } 115 if len(calls) == 0 { 116 return 117 } 118 119 if atomic.LoadInt64(&pendingRequestCount) > pendingRequestLimit { 120 rpcErrorResponsesCounter.Inc(int64(len(calls))) 121 err := &invalidRequestError{"server requests exceed the limit"} 122 logger.Debug(fmt.Sprintf("request error %v\n", err)) 123 h.startCallProc(func(cp *callProc) { 124 h.conn.writeJSON(cp.ctx, errorMessage(err)) 125 }) 126 return 127 } 128 129 // Process calls on a goroutine because they may block indefinitely: 130 h.startCallProc(func(cp *callProc) { 131 answers := make([]*jsonrpcMessage, 0, len(msgs)) 132 for _, msg := range calls { 133 if answer := h.handleCallMsg(cp, msg); answer != nil { 134 answers = append(answers, answer) 135 } 136 } 137 h.addSubscriptions(cp.notifiers) 138 if len(answers) > 0 { 139 h.conn.writeJSON(cp.ctx, answers) 140 } 141 for _, n := range cp.notifiers { 142 n.activate() 143 } 144 }) 145 } 146 147 // handleMsg handles a single message. 148 func (h *handler) handleMsg(msg *jsonrpcMessage) { 149 rpcTotalRequestsCounter.Inc(1) 150 if ok := h.handleImmediate(msg); ok { 151 return 152 } 153 154 if atomic.LoadInt64(&pendingRequestCount) > pendingRequestLimit { 155 rpcErrorResponsesCounter.Inc(1) 156 err := &invalidRequestError{"server requests exceed the limit"} 157 logger.Debug(fmt.Sprintf("request error %v\n", err)) 158 h.startCallProc(func(cp *callProc) { 159 h.conn.writeJSON(cp.ctx, errorMessage(err)) 160 }) 161 return 162 } 163 164 h.startCallProc(func(cp *callProc) { 165 answer := h.handleCallMsg(cp, msg) 166 h.addSubscriptions(cp.notifiers) 167 if answer != nil { 168 h.conn.writeJSON(cp.ctx, answer) 169 } 170 for _, n := range cp.notifiers { 171 n.activate() 172 } 173 }) 174 } 175 176 // close cancels all requests except for inflightReq and waits for 177 // call goroutines to shut down. 178 func (h *handler) close(err error, inflightReq *requestOp) { 179 h.cancelAllRequests(err, inflightReq) 180 h.callWG.Wait() 181 h.cancelRoot() 182 h.cancelServerSubscriptions(err) 183 } 184 185 // addRequestOp registers a request operation. 186 func (h *handler) addRequestOp(op *requestOp) { 187 for _, id := range op.ids { 188 h.respWait[string(id)] = op 189 } 190 } 191 192 // removeRequestOps stops waiting for the given request IDs. 193 func (h *handler) removeRequestOp(op *requestOp) { 194 for _, id := range op.ids { 195 delete(h.respWait, string(id)) 196 } 197 } 198 199 // cancelAllRequests unblocks and removes pending requests and active subscriptions. 200 func (h *handler) cancelAllRequests(err error, inflightReq *requestOp) { 201 didClose := make(map[*requestOp]bool) 202 if inflightReq != nil { 203 didClose[inflightReq] = true 204 } 205 206 for id, op := range h.respWait { 207 // Remove the op so that later calls will not close op.resp again. 208 delete(h.respWait, id) 209 210 if !didClose[op] { 211 op.err = err 212 close(op.resp) 213 didClose[op] = true 214 } 215 } 216 for id, sub := range h.clientSubs { 217 delete(h.clientSubs, id) 218 sub.quitWithError(err, false) 219 } 220 } 221 222 func (h *handler) addSubscriptions(nn []*Notifier) { 223 h.subLock.Lock() 224 defer h.subLock.Unlock() 225 226 for _, n := range nn { 227 if sub := n.takeSubscription(); sub != nil { 228 h.serverSubs[sub.ID] = sub 229 } 230 } 231 } 232 233 // cancelServerSubscriptions removes all subscriptions and closes their error channels. 234 func (h *handler) cancelServerSubscriptions(err error) { 235 h.subLock.Lock() 236 defer h.subLock.Unlock() 237 238 for id, s := range h.serverSubs { 239 s.err <- err 240 close(s.err) 241 delete(h.serverSubs, id) 242 } 243 } 244 245 // startCallProc runs fn in a new goroutine and starts tracking it in the h.calls wait group. 246 func (h *handler) startCallProc(fn func(*callProc)) { 247 atomic.AddInt64(&pendingRequestCount, 1) 248 rpcPendingRequestsCount.Inc(1) 249 h.callWG.Add(1) 250 go func() { 251 ctx, cancel := context.WithCancel(h.rootCtx) 252 defer h.callWG.Done() 253 defer cancel() 254 defer atomic.AddInt64(&pendingRequestCount, -1) 255 fn(&callProc{ctx: ctx}) 256 }() 257 } 258 259 // handleImmediate executes non-call messages. It returns false if the message is a 260 // call or requires a reply. 261 func (h *handler) handleImmediate(msg *jsonrpcMessage) bool { 262 start := time.Now() 263 switch { 264 case msg.isNotification(): 265 if strings.HasSuffix(msg.Method, notificationMethodSuffix) { 266 h.handleSubscriptionResult(msg) 267 return true 268 } 269 return false 270 case msg.isResponse(): 271 h.handleResponse(msg) 272 logger.Trace("Handled RPC response", "reqid", idForLog{msg.ID}, "duration", time.Since(start)) 273 return true 274 default: 275 return false 276 } 277 } 278 279 // handleSubscriptionResult processes subscription notifications. 280 func (h *handler) handleSubscriptionResult(msg *jsonrpcMessage) { 281 var result subscriptionResult 282 if err := json.Unmarshal(msg.Params, &result); err != nil { 283 logger.Debug("Dropping invalid subscription message") 284 return 285 } 286 logger.Trace("rpc client Notification", "msg", log.Lazy{Fn: func() string { 287 return fmt.Sprint("<-readResp: notification ", msg) 288 }}) 289 if h.clientSubs[result.ID] != nil { 290 h.clientSubs[result.ID].deliver(result.Result) 291 } 292 } 293 294 // handleResponse processes method call responses. 295 func (h *handler) handleResponse(msg *jsonrpcMessage) { 296 op := h.respWait[string(msg.ID)] 297 if op == nil { 298 logger.Debug("Unsolicited RPC response", "reqid", idForLog{msg.ID}) 299 return 300 } 301 logger.Trace("rpc client Response", "msg", log.Lazy{Fn: func() string { 302 return fmt.Sprint("<-readResp: response ", msg) 303 }}) 304 delete(h.respWait, string(msg.ID)) 305 // For normal responses, just forward the reply to Call/BatchCall. 306 if op.sub == nil { 307 op.resp <- msg 308 return 309 } 310 // For subscription responses, start the subscription if the server 311 // indicates success. KlaySubscribe gets unblocked in either case through 312 // the op.resp channel. 313 defer close(op.resp) 314 if msg.Error != nil { 315 op.err = msg.Error 316 return 317 } 318 if op.err = json.Unmarshal(msg.Result, &op.sub.subid); op.err == nil { 319 go op.sub.start() 320 h.clientSubs[op.sub.subid] = op.sub 321 } 322 } 323 324 // handleCallMsg executes a call message and returns the answer. 325 func (h *handler) handleCallMsg(ctx *callProc, msg *jsonrpcMessage) *jsonrpcMessage { 326 start := time.Now() 327 switch { 328 case msg.isNotification(): 329 h.handleCall(ctx, msg) 330 logger.Trace("Served "+msg.Method, "duration", time.Since(start)) 331 return nil 332 case msg.isCall(): 333 resp := h.handleCall(ctx, msg) 334 if resp.Error != nil { 335 logger.Debug("Served "+msg.Method, "reqid", idForLog{msg.ID}, "duration", time.Since(start), "err", resp.Error.Message) 336 } else { 337 logger.Trace("Served "+msg.Method, "reqid", idForLog{msg.ID}, "duration", time.Since(start)) 338 } 339 return resp 340 case msg.hasValidID(): 341 rpcErrorResponsesCounter.Inc(1) 342 return msg.errorResponse(&invalidRequestError{"invalid request"}) 343 default: 344 rpcErrorResponsesCounter.Inc(1) 345 return errorMessage(&invalidRequestError{"invalid request"}) 346 } 347 } 348 349 // handleCall processes method calls. 350 func (h *handler) handleCall(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { 351 if msg.isSubscribe() { 352 return h.handleSubscribe(cp, msg) 353 } 354 var callb *callback 355 if msg.isUnsubscribe() { 356 callb = h.unsubscribeCb 357 wsUnsubscriptionReqCounter.Inc(1) 358 } else { 359 callb = h.reg.callback(msg.Method) 360 } 361 if callb == nil { 362 rpcErrorResponsesCounter.Inc(1) 363 return msg.errorResponse(&methodNotFoundError{method: msg.Method}) 364 } 365 args, err := parsePositionalArguments(msg.Params, callb.argTypes) 366 if err != nil { 367 rpcErrorResponsesCounter.Inc(1) 368 return msg.errorResponse(&invalidParamsError{err.Error()}) 369 } 370 return h.runMethod(cp.ctx, msg, callb, args) 371 } 372 373 // handleSubscribe processes *_subscribe method calls. 374 func (h *handler) handleSubscribe(cp *callProc, msg *jsonrpcMessage) *jsonrpcMessage { 375 if !h.allowSubscribe { 376 rpcErrorResponsesCounter.Inc(1) 377 return msg.errorResponse(ErrNotificationsUnsupported) 378 } 379 380 if int32(len(h.serverSubs)) >= MaxSubscriptionPerWSConn { 381 rpcErrorResponsesCounter.Inc(1) 382 return msg.errorResponse(&callbackError{ 383 fmt.Sprintf("Maximum %d subscriptions are allowed for a websocket connection. "+ 384 "The limit can be updated with 'admin_setMaxSubscriptionPerWSConn' API", MaxSubscriptionPerWSConn), 385 }) 386 } 387 388 // Subscription method name is first argument. 389 name, err := parseSubscriptionName(msg.Params) 390 if err != nil { 391 rpcErrorResponsesCounter.Inc(1) 392 return msg.errorResponse(&invalidParamsError{err.Error()}) 393 } 394 namespace := msg.namespace() 395 callb := h.reg.subscription(namespace, name) 396 if callb == nil { 397 rpcErrorResponsesCounter.Inc(1) 398 return msg.errorResponse(&subscriptionNotFoundError{namespace, name}) 399 } 400 401 // Parse subscription name arg too, but remove it before calling the callback. 402 argTypes := append([]reflect.Type{stringType}, callb.argTypes...) 403 args, err := parsePositionalArguments(msg.Params, argTypes) 404 if err != nil { 405 rpcErrorResponsesCounter.Inc(1) 406 return msg.errorResponse(&invalidParamsError{err.Error()}) 407 } 408 args = args[1:] 409 410 // Install notifier in context so the subscription handler can find it. 411 n := &Notifier{h: h, namespace: namespace} 412 cp.notifiers = append(cp.notifiers, n) 413 ctx := context.WithValue(cp.ctx, notifierKey{}, n) 414 415 wsSubscriptionReqCounter.Inc(1) 416 417 return h.runMethod(ctx, msg, callb, args) 418 } 419 420 // runMethod runs the Go callback for an RPC method. 421 func (h *handler) runMethod(ctx context.Context, msg *jsonrpcMessage, callb *callback, args []reflect.Value) *jsonrpcMessage { 422 result, err := callb.call(ctx, msg.Method, args) 423 if err != nil { 424 rpcErrorResponsesCounter.Inc(1) 425 return msg.errorResponse(err) 426 } 427 428 rpcSuccessResponsesCounter.Inc(1) 429 return msg.response(result) 430 } 431 432 // unsubscribe is the callback function for all *_unsubscribe calls. 433 func (h *handler) unsubscribe(ctx context.Context, id ID) (bool, error) { 434 h.subLock.Lock() 435 defer h.subLock.Unlock() 436 437 s := h.serverSubs[id] 438 if s == nil { 439 return false, ErrSubscriptionNotFound 440 } 441 close(s.err) 442 delete(h.serverSubs, id) 443 return true, nil 444 } 445 446 type idForLog struct{ json.RawMessage } 447 448 func (id idForLog) String() string { 449 if s, err := strconv.Unquote(string(id.RawMessage)); err == nil { 450 return s 451 } 452 return string(id.RawMessage) 453 }