go-micro.dev/v5@v5.12.0/server/rpc_router.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"errors"
     6  	"fmt"
     7  	"io"
     8  	"reflect"
     9  	"runtime/debug"
    10  	"strings"
    11  	"sync"
    12  	"unicode"
    13  	"unicode/utf8"
    14  
    15  	"go-micro.dev/v5/codec"
    16  	merrors "go-micro.dev/v5/errors"
    17  	log "go-micro.dev/v5/logger"
    18  )
    19  
    20  var (
    21  	errLastStreamResponse = errors.New("EOS")
    22  
    23  	// Precompute the reflect type for error. Can't use error directly
    24  	// because Typeof takes an empty interface value. This is annoying.
    25  	typeOfError = reflect.TypeOf((*error)(nil)).Elem()
    26  )
    27  
    28  type methodType struct {
    29  	ArgType     reflect.Type
    30  	ReplyType   reflect.Type
    31  	ContextType reflect.Type
    32  	method      reflect.Method
    33  	sync.Mutex  // protects counters
    34  	stream      bool
    35  }
    36  
    37  type service struct {
    38  	typ    reflect.Type           // type of the receiver
    39  	method map[string]*methodType // registered methods
    40  	rcvr   reflect.Value          // receiver of methods for the service
    41  	name   string                 // name of service
    42  }
    43  
    44  type request struct {
    45  	msg  *codec.Message
    46  	next *request // for free list in Server
    47  }
    48  
    49  type response struct {
    50  	msg  *codec.Message
    51  	next *response // for free list in Server
    52  }
    53  
    54  // router represents an RPC router.
    55  type router struct {
    56  	ops RouterOptions
    57  
    58  	serviceMap map[string]*service
    59  
    60  	freeReq *request
    61  
    62  	freeResp *response
    63  
    64  	subscribers map[string][]*subscriber
    65  	name        string
    66  
    67  	// handler wrappers
    68  	hdlrWrappers []HandlerWrapper
    69  	// subscriber wrappers
    70  	subWrappers []SubscriberWrapper
    71  
    72  	su sync.RWMutex
    73  
    74  	mu sync.Mutex // protects the serviceMap
    75  
    76  	reqLock sync.Mutex // protects freeReq
    77  
    78  	respLock sync.Mutex // protects freeResp
    79  }
    80  
    81  // rpcRouter encapsulates functions that become a Router.
    82  type rpcRouter struct {
    83  	h func(context.Context, Request, interface{}) error
    84  	m func(context.Context, string, Message) error
    85  }
    86  
    87  func (r rpcRouter) ProcessMessage(ctx context.Context, subscriber string, msg Message) error {
    88  	return r.m(ctx, subscriber, msg)
    89  }
    90  
    91  func (r rpcRouter) ServeRequest(ctx context.Context, req Request, rsp Response) error {
    92  	return r.h(ctx, req, rsp)
    93  }
    94  
    95  func newRpcRouter(opts ...RouterOption) *router {
    96  	return &router{
    97  		ops:         NewRouterOptions(opts...),
    98  		serviceMap:  make(map[string]*service),
    99  		subscribers: make(map[string][]*subscriber),
   100  	}
   101  }
   102  
   103  // Is this an exported - upper case - name?
   104  func isExported(name string) bool {
   105  	rune, _ := utf8.DecodeRuneInString(name)
   106  	return unicode.IsUpper(rune)
   107  }
   108  
   109  // Is this type exported or a builtin?
   110  func isExportedOrBuiltinType(t reflect.Type) bool {
   111  	for t.Kind() == reflect.Ptr {
   112  		t = t.Elem()
   113  	}
   114  	// PkgPath will be non-empty even for an exported type,
   115  	// so we need to check the type name as well.
   116  	return isExported(t.Name()) || t.PkgPath() == ""
   117  }
   118  
   119  // prepareMethod returns a methodType for the provided method or nil
   120  // in case if the method was unsuitable.
   121  func prepareMethod(method reflect.Method, logger log.Logger) *methodType {
   122  	mtype := method.Type
   123  	mname := method.Name
   124  	var replyType, argType, contextType reflect.Type
   125  	var stream bool
   126  
   127  	// Method must be exported.
   128  	if method.PkgPath != "" {
   129  		return nil
   130  	}
   131  
   132  	switch mtype.NumIn() {
   133  	case 3:
   134  		// assuming streaming
   135  		argType = mtype.In(2)
   136  		contextType = mtype.In(1)
   137  		stream = true
   138  	case 4:
   139  		// method that takes a context
   140  		argType = mtype.In(2)
   141  		replyType = mtype.In(3)
   142  		contextType = mtype.In(1)
   143  	default:
   144  		logger.Logf(log.ErrorLevel, "method %v of %v has wrong number of ins: %v", mname, mtype, mtype.NumIn())
   145  		return nil
   146  	}
   147  
   148  	if stream {
   149  		// check stream type
   150  		streamType := reflect.TypeOf((*Stream)(nil)).Elem()
   151  		if !argType.Implements(streamType) {
   152  			logger.Logf(log.ErrorLevel, "%v argument does not implement Stream interface: %v", mname, argType)
   153  			return nil
   154  		}
   155  	} else {
   156  		// if not stream check the replyType
   157  
   158  		// First arg need not be a pointer.
   159  		if !isExportedOrBuiltinType(argType) {
   160  			logger.Logf(log.ErrorLevel, "%v argument type not exported: %v", mname, argType)
   161  			return nil
   162  		}
   163  
   164  		if replyType.Kind() != reflect.Ptr {
   165  			logger.Logf(log.ErrorLevel, "method %v reply type not a pointer: %v", mname, replyType)
   166  			return nil
   167  		}
   168  
   169  		// Reply type must be exported.
   170  		if !isExportedOrBuiltinType(replyType) {
   171  			logger.Logf(log.ErrorLevel, "method %v reply type not exported: %v", mname, replyType)
   172  			return nil
   173  		}
   174  	}
   175  
   176  	// Method needs one out.
   177  	if mtype.NumOut() != 1 {
   178  		logger.Logf(log.ErrorLevel, "method %v has wrong number of outs: %v", mname, mtype.NumOut())
   179  		return nil
   180  	}
   181  
   182  	// The return type of the method must be error.
   183  	if returnType := mtype.Out(0); returnType != typeOfError {
   184  		logger.Logf(log.ErrorLevel, "method %v returns %v not error", mname, returnType.String())
   185  		return nil
   186  	}
   187  
   188  	return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
   189  }
   190  
   191  func (router *router) sendResponse(sending sync.Locker,
   192  	req *request,
   193  	reply interface{},
   194  	cc codec.Writer,
   195  	last bool) error {
   196  	msg := new(codec.Message)
   197  	msg.Type = codec.Response
   198  	resp := router.getResponse()
   199  	resp.msg = msg
   200  
   201  	resp.msg.Id = req.msg.Id
   202  
   203  	sending.Lock()
   204  	err := cc.Write(resp.msg, reply)
   205  	sending.Unlock()
   206  
   207  	router.freeResponse(resp)
   208  
   209  	return err
   210  }
   211  
   212  func (s *service) call(ctx context.Context,
   213  	router *router,
   214  	sending *sync.Mutex,
   215  	mtype *methodType,
   216  	req *request,
   217  	argv, replyv reflect.Value,
   218  	cc codec.Writer) error {
   219  	defer router.freeRequest(req)
   220  
   221  	function := mtype.method.Func
   222  	var returnValues []reflect.Value
   223  
   224  	r := &rpcRequest{
   225  		service:     req.msg.Target,
   226  		contentType: req.msg.Header["Content-Type"],
   227  		method:      req.msg.Method,
   228  		endpoint:    req.msg.Endpoint,
   229  		body:        req.msg.Body,
   230  		header:      req.msg.Header,
   231  	}
   232  
   233  	// only set if not nil
   234  	if argv.IsValid() {
   235  		r.rawBody = argv.Interface()
   236  	}
   237  
   238  	if !mtype.stream {
   239  		fn := func(ctx context.Context, req Request, rsp interface{}) error {
   240  			returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx),
   241  				reflect.ValueOf(argv.Interface()), reflect.ValueOf(rsp)})
   242  
   243  			// The return value for the method is an error.
   244  			if err := returnValues[0].Interface(); err != nil {
   245  				return err.(error)
   246  			}
   247  
   248  			return nil
   249  		}
   250  
   251  		// wrap the handler
   252  		for i := len(router.hdlrWrappers); i > 0; i-- {
   253  			fn = router.hdlrWrappers[i-1](fn)
   254  		}
   255  
   256  		// execute handler
   257  		if err := fn(ctx, r, replyv.Interface()); err != nil {
   258  			return err
   259  		}
   260  
   261  		// send response
   262  		return router.sendResponse(sending, req, replyv.Interface(), cc, true)
   263  	}
   264  
   265  	// declare a local error to see if we errored out already
   266  	// keep track of the type, to make sure we return
   267  	// the same one consistently
   268  	rawStream := &rpcStream{
   269  		context: ctx,
   270  		codec:   cc.(codec.Codec),
   271  		request: r,
   272  		id:      req.msg.Id,
   273  	}
   274  
   275  	// Invoke the method, providing a new value for the reply.
   276  	fn := func(ctx context.Context, req Request, stream interface{}) error {
   277  		returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)})
   278  
   279  		if err := returnValues[0].Interface(); err != nil {
   280  			// the function returned an error, we use that
   281  			return err.(error)
   282  		} else if serr := rawStream.Error(); serr == io.EOF || serr == io.ErrUnexpectedEOF {
   283  			return nil
   284  		} else {
   285  			// no error, we send the special EOS error
   286  			return errLastStreamResponse
   287  		}
   288  	}
   289  
   290  	// wrap the handler
   291  	for i := len(router.hdlrWrappers); i > 0; i-- {
   292  		fn = router.hdlrWrappers[i-1](fn)
   293  	}
   294  
   295  	// client.Stream request
   296  	r.stream = true
   297  
   298  	// execute handler
   299  	return fn(ctx, r, rawStream)
   300  }
   301  
   302  func (m *methodType) prepareContext(ctx context.Context) reflect.Value {
   303  	if contextv := reflect.ValueOf(ctx); contextv.IsValid() {
   304  		return contextv
   305  	}
   306  
   307  	return reflect.Zero(m.ContextType)
   308  }
   309  
   310  func (router *router) getRequest() *request {
   311  	router.reqLock.Lock()
   312  	defer router.reqLock.Unlock()
   313  
   314  	req := router.freeReq
   315  	if req == nil {
   316  		req = new(request)
   317  	} else {
   318  		router.freeReq = req.next
   319  		*req = request{}
   320  	}
   321  
   322  	return req
   323  }
   324  
   325  func (router *router) freeRequest(req *request) {
   326  	router.reqLock.Lock()
   327  	defer router.reqLock.Unlock()
   328  
   329  	req.next = router.freeReq
   330  	router.freeReq = req
   331  }
   332  
   333  func (router *router) getResponse() *response {
   334  	router.respLock.Lock()
   335  	defer router.respLock.Unlock()
   336  
   337  	resp := router.freeResp
   338  	if resp == nil {
   339  		resp = new(response)
   340  	} else {
   341  		router.freeResp = resp.next
   342  		*resp = response{}
   343  	}
   344  
   345  	return resp
   346  }
   347  
   348  func (router *router) freeResponse(resp *response) {
   349  	router.respLock.Lock()
   350  	defer router.respLock.Unlock()
   351  
   352  	resp.next = router.freeResp
   353  	router.freeResp = resp
   354  }
   355  
   356  func (router *router) readRequest(r Request) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) {
   357  	cc := r.Codec()
   358  
   359  	service, mtype, req, keepReading, err = router.readHeader(cc)
   360  	if err != nil {
   361  		if !keepReading {
   362  			return
   363  		}
   364  		// discard body
   365  		cc.ReadBody(nil)
   366  
   367  		return
   368  	}
   369  
   370  	// is it a streaming request? then we don't read the body
   371  	if mtype.stream {
   372  		if cc.(codec.Codec).String() != "grpc" {
   373  			cc.ReadBody(nil)
   374  		}
   375  		return
   376  	}
   377  
   378  	// Decode the argument value.
   379  	argIsValue := false // if true, need to indirect before calling.
   380  	if mtype.ArgType.Kind() == reflect.Ptr {
   381  		argv = reflect.New(mtype.ArgType.Elem())
   382  	} else {
   383  		argv = reflect.New(mtype.ArgType)
   384  		argIsValue = true
   385  	}
   386  
   387  	// argv guaranteed to be a pointer now.
   388  	if err = cc.ReadBody(argv.Interface()); err != nil {
   389  		return
   390  	}
   391  
   392  	if argIsValue {
   393  		argv = argv.Elem()
   394  	}
   395  
   396  	if !mtype.stream {
   397  		replyv = reflect.New(mtype.ReplyType.Elem())
   398  	}
   399  
   400  	return
   401  }
   402  
   403  func (router *router) readHeader(cc codec.Reader) (service *service, mtype *methodType, req *request, keepReading bool, err error) {
   404  	// Grab the request header.
   405  	msg := new(codec.Message)
   406  	msg.Type = codec.Request
   407  	req = router.getRequest()
   408  	req.msg = msg
   409  
   410  	err = cc.ReadHeader(msg, msg.Type)
   411  	if err != nil {
   412  		req = nil
   413  		if err == io.EOF || err == io.ErrUnexpectedEOF {
   414  			return
   415  		}
   416  		err = errors.New("rpc: router cannot decode request: " + err.Error())
   417  
   418  		return
   419  	}
   420  
   421  	// We read the header successfully. If we see an error now,
   422  	// we can still recover and move on to the next request.
   423  	keepReading = true
   424  
   425  	serviceMethod := strings.Split(req.msg.Endpoint, ".")
   426  	if len(serviceMethod) != 2 {
   427  		err = errors.New("rpc: service/endpoint request ill-formed: " + req.msg.Endpoint)
   428  		return
   429  	}
   430  
   431  	// Look up the request.
   432  	router.mu.Lock()
   433  	service = router.serviceMap[serviceMethod[0]]
   434  	router.mu.Unlock()
   435  
   436  	if service == nil {
   437  		err = errors.New("rpc: can't find service " + serviceMethod[0])
   438  		return
   439  	}
   440  
   441  	mtype = service.method[serviceMethod[1]]
   442  	if mtype == nil {
   443  		err = errors.New("rpc: can't find method " + serviceMethod[1])
   444  	}
   445  
   446  	return
   447  }
   448  
   449  func (router *router) NewHandler(h interface{}, opts ...HandlerOption) Handler {
   450  	return NewRpcHandler(h, opts...)
   451  }
   452  
   453  func (router *router) Handle(h Handler) error {
   454  	router.mu.Lock()
   455  	defer router.mu.Unlock()
   456  
   457  	if router.serviceMap == nil {
   458  		router.serviceMap = make(map[string]*service)
   459  	}
   460  
   461  	if len(h.Name()) == 0 {
   462  		return errors.New("rpc.Handle: handler has no name")
   463  	}
   464  
   465  	if !isExported(h.Name()) {
   466  		return errors.New("rpc.Handle: type " + h.Name() + " is not exported")
   467  	}
   468  
   469  	rcvr := h.Handler()
   470  	s := new(service)
   471  	s.typ = reflect.TypeOf(rcvr)
   472  	s.rcvr = reflect.ValueOf(rcvr)
   473  
   474  	// check name
   475  	if _, present := router.serviceMap[h.Name()]; present {
   476  		return errors.New("rpc.Handle: service already defined: " + h.Name())
   477  	}
   478  
   479  	s.name = h.Name()
   480  	s.method = make(map[string]*methodType)
   481  
   482  	// Install the methods
   483  	for m := 0; m < s.typ.NumMethod(); m++ {
   484  		method := s.typ.Method(m)
   485  		if mt := prepareMethod(method, router.ops.Logger); mt != nil {
   486  			s.method[method.Name] = mt
   487  		}
   488  	}
   489  
   490  	// Check there are methods
   491  	if len(s.method) == 0 {
   492  		return errors.New("rpc Register: type " + s.name + " has no exported methods of suitable type")
   493  	}
   494  
   495  	// save handler
   496  	router.serviceMap[s.name] = s
   497  
   498  	return nil
   499  }
   500  
   501  func (router *router) ServeRequest(ctx context.Context, r Request, rsp Response) error {
   502  	sending := new(sync.Mutex)
   503  	service, mtype, req, argv, replyv, keepReading, err := router.readRequest(r)
   504  	if err != nil {
   505  		if !keepReading {
   506  			return err
   507  		}
   508  		// send a response if we actually managed to read a header.
   509  		if req != nil {
   510  			router.freeRequest(req)
   511  		}
   512  
   513  		return err
   514  	}
   515  
   516  	return service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec())
   517  }
   518  
   519  func (router *router) NewSubscriber(topic string, handler interface{}, opts ...SubscriberOption) Subscriber {
   520  	return newSubscriber(topic, handler, opts...)
   521  }
   522  
   523  func (router *router) Subscribe(s Subscriber) error {
   524  	sub, ok := s.(*subscriber)
   525  	if !ok {
   526  		return fmt.Errorf("invalid subscriber: expected *subscriber")
   527  	}
   528  
   529  	if len(sub.handlers) == 0 {
   530  		return fmt.Errorf("invalid subscriber: no handler functions")
   531  	}
   532  
   533  	if err := validateSubscriber(sub); err != nil {
   534  		return err
   535  	}
   536  
   537  	router.su.Lock()
   538  	defer router.su.Unlock()
   539  
   540  	// append to subscribers
   541  	subs := router.subscribers[sub.Topic()]
   542  	subs = append(subs, sub)
   543  	router.subscribers[sub.Topic()] = subs
   544  
   545  	return nil
   546  }
   547  
   548  func (router *router) ProcessMessage(ctx context.Context, subscriber string, msg Message) (err error) {
   549  	defer func() {
   550  		// recover any panics
   551  		if r := recover(); r != nil {
   552  			router.ops.Logger.Logf(log.ErrorLevel, "panic recovered: %v", r)
   553  			router.ops.Logger.Log(log.ErrorLevel, string(debug.Stack()))
   554  			err = merrors.InternalServerError("go.micro.server", "panic recovered: %v", r)
   555  		}
   556  	}()
   557  
   558  	// get the subscribers by topic
   559  	router.su.RLock()
   560  	subs, ok := router.subscribers[subscriber]
   561  	router.su.RUnlock()
   562  	if !ok {
   563  		log.Warnf("Subscriber not found for topic %s", msg.Topic())
   564  		return nil
   565  	}
   566  
   567  	var errResults []string
   568  
   569  	// we may have multiple subscribers for the topic
   570  	for _, sub := range subs {
   571  		// we may have multiple handlers per subscriber
   572  		for i := 0; i < len(sub.handlers); i++ {
   573  			// get the handler
   574  			handler := sub.handlers[i]
   575  
   576  			var isVal bool
   577  			var req reflect.Value
   578  
   579  			// check whether the handler is a pointer
   580  			if handler.reqType.Kind() == reflect.Ptr {
   581  				req = reflect.New(handler.reqType.Elem())
   582  			} else {
   583  				req = reflect.New(handler.reqType)
   584  				isVal = true
   585  			}
   586  
   587  			// if its a value get the element
   588  			if isVal {
   589  				req = req.Elem()
   590  			}
   591  
   592  			cc := msg.Codec()
   593  
   594  			// read the header. mostly a noop
   595  			if err = cc.ReadHeader(&codec.Message{}, codec.Event); err != nil {
   596  				return err
   597  			}
   598  
   599  			// make request value a pointer, if it's not already
   600  			reqVal := req.Interface()
   601  			if req.CanAddr() {
   602  				reqVal = req.Addr().Interface()
   603  			}
   604  
   605  			// read the body into the handler request value
   606  			if err = cc.ReadBody(reqVal); err != nil {
   607  				return err
   608  			}
   609  
   610  			// create the handler which will honor the SubscriberFunc type
   611  			fn := func(ctx context.Context, msg Message) error {
   612  				var vals []reflect.Value
   613  				if sub.typ.Kind() != reflect.Func {
   614  					vals = append(vals, sub.rcvr)
   615  				}
   616  				if handler.ctxType != nil {
   617  					vals = append(vals, reflect.ValueOf(ctx))
   618  				}
   619  
   620  				// values to pass the handler
   621  				vals = append(vals, reflect.ValueOf(msg.Payload()))
   622  
   623  				// execute the actuall call of the handler
   624  				returnValues := handler.method.Call(vals)
   625  				if rerr := returnValues[0].Interface(); rerr != nil {
   626  					err = rerr.(error)
   627  				}
   628  				return err
   629  			}
   630  
   631  			// wrap with subscriber wrappers
   632  			for i := len(router.subWrappers); i > 0; i-- {
   633  				fn = router.subWrappers[i-1](fn)
   634  			}
   635  
   636  			// create new rpc message
   637  			rpcMsg := &rpcMessage{
   638  				topic:       msg.Topic(),
   639  				contentType: msg.ContentType(),
   640  				payload:     req.Interface(),
   641  				codec:       msg.(*rpcMessage).codec,
   642  				header:      msg.Header(),
   643  				body:        msg.Body(),
   644  			}
   645  
   646  			// execute the message handler
   647  			if err = fn(ctx, rpcMsg); err != nil {
   648  				errResults = append(errResults, err.Error())
   649  			}
   650  		}
   651  	}
   652  
   653  	// if no errors just return
   654  	if len(errResults) > 0 {
   655  		err = merrors.InternalServerError("go.micro.server", "subscriber error: %v", strings.Join(errResults, "\n"))
   656  	}
   657  
   658  	return err
   659  }