go-micro.dev/v5@v5.12.0/server/rpc_router.go (about) 1 package server 2 3 import ( 4 "context" 5 "errors" 6 "fmt" 7 "io" 8 "reflect" 9 "runtime/debug" 10 "strings" 11 "sync" 12 "unicode" 13 "unicode/utf8" 14 15 "go-micro.dev/v5/codec" 16 merrors "go-micro.dev/v5/errors" 17 log "go-micro.dev/v5/logger" 18 ) 19 20 var ( 21 errLastStreamResponse = errors.New("EOS") 22 23 // Precompute the reflect type for error. Can't use error directly 24 // because Typeof takes an empty interface value. This is annoying. 25 typeOfError = reflect.TypeOf((*error)(nil)).Elem() 26 ) 27 28 type methodType struct { 29 ArgType reflect.Type 30 ReplyType reflect.Type 31 ContextType reflect.Type 32 method reflect.Method 33 sync.Mutex // protects counters 34 stream bool 35 } 36 37 type service struct { 38 typ reflect.Type // type of the receiver 39 method map[string]*methodType // registered methods 40 rcvr reflect.Value // receiver of methods for the service 41 name string // name of service 42 } 43 44 type request struct { 45 msg *codec.Message 46 next *request // for free list in Server 47 } 48 49 type response struct { 50 msg *codec.Message 51 next *response // for free list in Server 52 } 53 54 // router represents an RPC router. 55 type router struct { 56 ops RouterOptions 57 58 serviceMap map[string]*service 59 60 freeReq *request 61 62 freeResp *response 63 64 subscribers map[string][]*subscriber 65 name string 66 67 // handler wrappers 68 hdlrWrappers []HandlerWrapper 69 // subscriber wrappers 70 subWrappers []SubscriberWrapper 71 72 su sync.RWMutex 73 74 mu sync.Mutex // protects the serviceMap 75 76 reqLock sync.Mutex // protects freeReq 77 78 respLock sync.Mutex // protects freeResp 79 } 80 81 // rpcRouter encapsulates functions that become a Router. 82 type rpcRouter struct { 83 h func(context.Context, Request, interface{}) error 84 m func(context.Context, string, Message) error 85 } 86 87 func (r rpcRouter) ProcessMessage(ctx context.Context, subscriber string, msg Message) error { 88 return r.m(ctx, subscriber, msg) 89 } 90 91 func (r rpcRouter) ServeRequest(ctx context.Context, req Request, rsp Response) error { 92 return r.h(ctx, req, rsp) 93 } 94 95 func newRpcRouter(opts ...RouterOption) *router { 96 return &router{ 97 ops: NewRouterOptions(opts...), 98 serviceMap: make(map[string]*service), 99 subscribers: make(map[string][]*subscriber), 100 } 101 } 102 103 // Is this an exported - upper case - name? 104 func isExported(name string) bool { 105 rune, _ := utf8.DecodeRuneInString(name) 106 return unicode.IsUpper(rune) 107 } 108 109 // Is this type exported or a builtin? 110 func isExportedOrBuiltinType(t reflect.Type) bool { 111 for t.Kind() == reflect.Ptr { 112 t = t.Elem() 113 } 114 // PkgPath will be non-empty even for an exported type, 115 // so we need to check the type name as well. 116 return isExported(t.Name()) || t.PkgPath() == "" 117 } 118 119 // prepareMethod returns a methodType for the provided method or nil 120 // in case if the method was unsuitable. 121 func prepareMethod(method reflect.Method, logger log.Logger) *methodType { 122 mtype := method.Type 123 mname := method.Name 124 var replyType, argType, contextType reflect.Type 125 var stream bool 126 127 // Method must be exported. 128 if method.PkgPath != "" { 129 return nil 130 } 131 132 switch mtype.NumIn() { 133 case 3: 134 // assuming streaming 135 argType = mtype.In(2) 136 contextType = mtype.In(1) 137 stream = true 138 case 4: 139 // method that takes a context 140 argType = mtype.In(2) 141 replyType = mtype.In(3) 142 contextType = mtype.In(1) 143 default: 144 logger.Logf(log.ErrorLevel, "method %v of %v has wrong number of ins: %v", mname, mtype, mtype.NumIn()) 145 return nil 146 } 147 148 if stream { 149 // check stream type 150 streamType := reflect.TypeOf((*Stream)(nil)).Elem() 151 if !argType.Implements(streamType) { 152 logger.Logf(log.ErrorLevel, "%v argument does not implement Stream interface: %v", mname, argType) 153 return nil 154 } 155 } else { 156 // if not stream check the replyType 157 158 // First arg need not be a pointer. 159 if !isExportedOrBuiltinType(argType) { 160 logger.Logf(log.ErrorLevel, "%v argument type not exported: %v", mname, argType) 161 return nil 162 } 163 164 if replyType.Kind() != reflect.Ptr { 165 logger.Logf(log.ErrorLevel, "method %v reply type not a pointer: %v", mname, replyType) 166 return nil 167 } 168 169 // Reply type must be exported. 170 if !isExportedOrBuiltinType(replyType) { 171 logger.Logf(log.ErrorLevel, "method %v reply type not exported: %v", mname, replyType) 172 return nil 173 } 174 } 175 176 // Method needs one out. 177 if mtype.NumOut() != 1 { 178 logger.Logf(log.ErrorLevel, "method %v has wrong number of outs: %v", mname, mtype.NumOut()) 179 return nil 180 } 181 182 // The return type of the method must be error. 183 if returnType := mtype.Out(0); returnType != typeOfError { 184 logger.Logf(log.ErrorLevel, "method %v returns %v not error", mname, returnType.String()) 185 return nil 186 } 187 188 return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream} 189 } 190 191 func (router *router) sendResponse(sending sync.Locker, 192 req *request, 193 reply interface{}, 194 cc codec.Writer, 195 last bool) error { 196 msg := new(codec.Message) 197 msg.Type = codec.Response 198 resp := router.getResponse() 199 resp.msg = msg 200 201 resp.msg.Id = req.msg.Id 202 203 sending.Lock() 204 err := cc.Write(resp.msg, reply) 205 sending.Unlock() 206 207 router.freeResponse(resp) 208 209 return err 210 } 211 212 func (s *service) call(ctx context.Context, 213 router *router, 214 sending *sync.Mutex, 215 mtype *methodType, 216 req *request, 217 argv, replyv reflect.Value, 218 cc codec.Writer) error { 219 defer router.freeRequest(req) 220 221 function := mtype.method.Func 222 var returnValues []reflect.Value 223 224 r := &rpcRequest{ 225 service: req.msg.Target, 226 contentType: req.msg.Header["Content-Type"], 227 method: req.msg.Method, 228 endpoint: req.msg.Endpoint, 229 body: req.msg.Body, 230 header: req.msg.Header, 231 } 232 233 // only set if not nil 234 if argv.IsValid() { 235 r.rawBody = argv.Interface() 236 } 237 238 if !mtype.stream { 239 fn := func(ctx context.Context, req Request, rsp interface{}) error { 240 returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), 241 reflect.ValueOf(argv.Interface()), reflect.ValueOf(rsp)}) 242 243 // The return value for the method is an error. 244 if err := returnValues[0].Interface(); err != nil { 245 return err.(error) 246 } 247 248 return nil 249 } 250 251 // wrap the handler 252 for i := len(router.hdlrWrappers); i > 0; i-- { 253 fn = router.hdlrWrappers[i-1](fn) 254 } 255 256 // execute handler 257 if err := fn(ctx, r, replyv.Interface()); err != nil { 258 return err 259 } 260 261 // send response 262 return router.sendResponse(sending, req, replyv.Interface(), cc, true) 263 } 264 265 // declare a local error to see if we errored out already 266 // keep track of the type, to make sure we return 267 // the same one consistently 268 rawStream := &rpcStream{ 269 context: ctx, 270 codec: cc.(codec.Codec), 271 request: r, 272 id: req.msg.Id, 273 } 274 275 // Invoke the method, providing a new value for the reply. 276 fn := func(ctx context.Context, req Request, stream interface{}) error { 277 returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)}) 278 279 if err := returnValues[0].Interface(); err != nil { 280 // the function returned an error, we use that 281 return err.(error) 282 } else if serr := rawStream.Error(); serr == io.EOF || serr == io.ErrUnexpectedEOF { 283 return nil 284 } else { 285 // no error, we send the special EOS error 286 return errLastStreamResponse 287 } 288 } 289 290 // wrap the handler 291 for i := len(router.hdlrWrappers); i > 0; i-- { 292 fn = router.hdlrWrappers[i-1](fn) 293 } 294 295 // client.Stream request 296 r.stream = true 297 298 // execute handler 299 return fn(ctx, r, rawStream) 300 } 301 302 func (m *methodType) prepareContext(ctx context.Context) reflect.Value { 303 if contextv := reflect.ValueOf(ctx); contextv.IsValid() { 304 return contextv 305 } 306 307 return reflect.Zero(m.ContextType) 308 } 309 310 func (router *router) getRequest() *request { 311 router.reqLock.Lock() 312 defer router.reqLock.Unlock() 313 314 req := router.freeReq 315 if req == nil { 316 req = new(request) 317 } else { 318 router.freeReq = req.next 319 *req = request{} 320 } 321 322 return req 323 } 324 325 func (router *router) freeRequest(req *request) { 326 router.reqLock.Lock() 327 defer router.reqLock.Unlock() 328 329 req.next = router.freeReq 330 router.freeReq = req 331 } 332 333 func (router *router) getResponse() *response { 334 router.respLock.Lock() 335 defer router.respLock.Unlock() 336 337 resp := router.freeResp 338 if resp == nil { 339 resp = new(response) 340 } else { 341 router.freeResp = resp.next 342 *resp = response{} 343 } 344 345 return resp 346 } 347 348 func (router *router) freeResponse(resp *response) { 349 router.respLock.Lock() 350 defer router.respLock.Unlock() 351 352 resp.next = router.freeResp 353 router.freeResp = resp 354 } 355 356 func (router *router) readRequest(r Request) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) { 357 cc := r.Codec() 358 359 service, mtype, req, keepReading, err = router.readHeader(cc) 360 if err != nil { 361 if !keepReading { 362 return 363 } 364 // discard body 365 cc.ReadBody(nil) 366 367 return 368 } 369 370 // is it a streaming request? then we don't read the body 371 if mtype.stream { 372 if cc.(codec.Codec).String() != "grpc" { 373 cc.ReadBody(nil) 374 } 375 return 376 } 377 378 // Decode the argument value. 379 argIsValue := false // if true, need to indirect before calling. 380 if mtype.ArgType.Kind() == reflect.Ptr { 381 argv = reflect.New(mtype.ArgType.Elem()) 382 } else { 383 argv = reflect.New(mtype.ArgType) 384 argIsValue = true 385 } 386 387 // argv guaranteed to be a pointer now. 388 if err = cc.ReadBody(argv.Interface()); err != nil { 389 return 390 } 391 392 if argIsValue { 393 argv = argv.Elem() 394 } 395 396 if !mtype.stream { 397 replyv = reflect.New(mtype.ReplyType.Elem()) 398 } 399 400 return 401 } 402 403 func (router *router) readHeader(cc codec.Reader) (service *service, mtype *methodType, req *request, keepReading bool, err error) { 404 // Grab the request header. 405 msg := new(codec.Message) 406 msg.Type = codec.Request 407 req = router.getRequest() 408 req.msg = msg 409 410 err = cc.ReadHeader(msg, msg.Type) 411 if err != nil { 412 req = nil 413 if err == io.EOF || err == io.ErrUnexpectedEOF { 414 return 415 } 416 err = errors.New("rpc: router cannot decode request: " + err.Error()) 417 418 return 419 } 420 421 // We read the header successfully. If we see an error now, 422 // we can still recover and move on to the next request. 423 keepReading = true 424 425 serviceMethod := strings.Split(req.msg.Endpoint, ".") 426 if len(serviceMethod) != 2 { 427 err = errors.New("rpc: service/endpoint request ill-formed: " + req.msg.Endpoint) 428 return 429 } 430 431 // Look up the request. 432 router.mu.Lock() 433 service = router.serviceMap[serviceMethod[0]] 434 router.mu.Unlock() 435 436 if service == nil { 437 err = errors.New("rpc: can't find service " + serviceMethod[0]) 438 return 439 } 440 441 mtype = service.method[serviceMethod[1]] 442 if mtype == nil { 443 err = errors.New("rpc: can't find method " + serviceMethod[1]) 444 } 445 446 return 447 } 448 449 func (router *router) NewHandler(h interface{}, opts ...HandlerOption) Handler { 450 return NewRpcHandler(h, opts...) 451 } 452 453 func (router *router) Handle(h Handler) error { 454 router.mu.Lock() 455 defer router.mu.Unlock() 456 457 if router.serviceMap == nil { 458 router.serviceMap = make(map[string]*service) 459 } 460 461 if len(h.Name()) == 0 { 462 return errors.New("rpc.Handle: handler has no name") 463 } 464 465 if !isExported(h.Name()) { 466 return errors.New("rpc.Handle: type " + h.Name() + " is not exported") 467 } 468 469 rcvr := h.Handler() 470 s := new(service) 471 s.typ = reflect.TypeOf(rcvr) 472 s.rcvr = reflect.ValueOf(rcvr) 473 474 // check name 475 if _, present := router.serviceMap[h.Name()]; present { 476 return errors.New("rpc.Handle: service already defined: " + h.Name()) 477 } 478 479 s.name = h.Name() 480 s.method = make(map[string]*methodType) 481 482 // Install the methods 483 for m := 0; m < s.typ.NumMethod(); m++ { 484 method := s.typ.Method(m) 485 if mt := prepareMethod(method, router.ops.Logger); mt != nil { 486 s.method[method.Name] = mt 487 } 488 } 489 490 // Check there are methods 491 if len(s.method) == 0 { 492 return errors.New("rpc Register: type " + s.name + " has no exported methods of suitable type") 493 } 494 495 // save handler 496 router.serviceMap[s.name] = s 497 498 return nil 499 } 500 501 func (router *router) ServeRequest(ctx context.Context, r Request, rsp Response) error { 502 sending := new(sync.Mutex) 503 service, mtype, req, argv, replyv, keepReading, err := router.readRequest(r) 504 if err != nil { 505 if !keepReading { 506 return err 507 } 508 // send a response if we actually managed to read a header. 509 if req != nil { 510 router.freeRequest(req) 511 } 512 513 return err 514 } 515 516 return service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec()) 517 } 518 519 func (router *router) NewSubscriber(topic string, handler interface{}, opts ...SubscriberOption) Subscriber { 520 return newSubscriber(topic, handler, opts...) 521 } 522 523 func (router *router) Subscribe(s Subscriber) error { 524 sub, ok := s.(*subscriber) 525 if !ok { 526 return fmt.Errorf("invalid subscriber: expected *subscriber") 527 } 528 529 if len(sub.handlers) == 0 { 530 return fmt.Errorf("invalid subscriber: no handler functions") 531 } 532 533 if err := validateSubscriber(sub); err != nil { 534 return err 535 } 536 537 router.su.Lock() 538 defer router.su.Unlock() 539 540 // append to subscribers 541 subs := router.subscribers[sub.Topic()] 542 subs = append(subs, sub) 543 router.subscribers[sub.Topic()] = subs 544 545 return nil 546 } 547 548 func (router *router) ProcessMessage(ctx context.Context, subscriber string, msg Message) (err error) { 549 defer func() { 550 // recover any panics 551 if r := recover(); r != nil { 552 router.ops.Logger.Logf(log.ErrorLevel, "panic recovered: %v", r) 553 router.ops.Logger.Log(log.ErrorLevel, string(debug.Stack())) 554 err = merrors.InternalServerError("go.micro.server", "panic recovered: %v", r) 555 } 556 }() 557 558 // get the subscribers by topic 559 router.su.RLock() 560 subs, ok := router.subscribers[subscriber] 561 router.su.RUnlock() 562 if !ok { 563 log.Warnf("Subscriber not found for topic %s", msg.Topic()) 564 return nil 565 } 566 567 var errResults []string 568 569 // we may have multiple subscribers for the topic 570 for _, sub := range subs { 571 // we may have multiple handlers per subscriber 572 for i := 0; i < len(sub.handlers); i++ { 573 // get the handler 574 handler := sub.handlers[i] 575 576 var isVal bool 577 var req reflect.Value 578 579 // check whether the handler is a pointer 580 if handler.reqType.Kind() == reflect.Ptr { 581 req = reflect.New(handler.reqType.Elem()) 582 } else { 583 req = reflect.New(handler.reqType) 584 isVal = true 585 } 586 587 // if its a value get the element 588 if isVal { 589 req = req.Elem() 590 } 591 592 cc := msg.Codec() 593 594 // read the header. mostly a noop 595 if err = cc.ReadHeader(&codec.Message{}, codec.Event); err != nil { 596 return err 597 } 598 599 // make request value a pointer, if it's not already 600 reqVal := req.Interface() 601 if req.CanAddr() { 602 reqVal = req.Addr().Interface() 603 } 604 605 // read the body into the handler request value 606 if err = cc.ReadBody(reqVal); err != nil { 607 return err 608 } 609 610 // create the handler which will honor the SubscriberFunc type 611 fn := func(ctx context.Context, msg Message) error { 612 var vals []reflect.Value 613 if sub.typ.Kind() != reflect.Func { 614 vals = append(vals, sub.rcvr) 615 } 616 if handler.ctxType != nil { 617 vals = append(vals, reflect.ValueOf(ctx)) 618 } 619 620 // values to pass the handler 621 vals = append(vals, reflect.ValueOf(msg.Payload())) 622 623 // execute the actuall call of the handler 624 returnValues := handler.method.Call(vals) 625 if rerr := returnValues[0].Interface(); rerr != nil { 626 err = rerr.(error) 627 } 628 return err 629 } 630 631 // wrap with subscriber wrappers 632 for i := len(router.subWrappers); i > 0; i-- { 633 fn = router.subWrappers[i-1](fn) 634 } 635 636 // create new rpc message 637 rpcMsg := &rpcMessage{ 638 topic: msg.Topic(), 639 contentType: msg.ContentType(), 640 payload: req.Interface(), 641 codec: msg.(*rpcMessage).codec, 642 header: msg.Header(), 643 body: msg.Body(), 644 } 645 646 // execute the message handler 647 if err = fn(ctx, rpcMsg); err != nil { 648 errResults = append(errResults, err.Error()) 649 } 650 } 651 } 652 653 // if no errors just return 654 if len(errResults) > 0 { 655 err = merrors.InternalServerError("go.micro.server", "subscriber error: %v", strings.Join(errResults, "\n")) 656 } 657 658 return err 659 }