gitee.com/sasukebo/go-micro/v4@v4.7.1/server/rpc_router.go (about)

     1  package server
     2  
     3  // Copyright 2009 The Go Authors. All rights reserved.
     4  // Use of this source code is governed by a BSD-style
     5  // license that can be found in the LICENSE file.
     6  //
// NOTE: this router is adapted from Go's net/rpc (see the copyright above) and is a candidate for replacement.
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"reflect"
    15  	"runtime/debug"
    16  	"strings"
    17  	"sync"
    18  	"unicode"
    19  	"unicode/utf8"
    20  
    21  	"gitee.com/sasukebo/go-micro/v4/codec"
    22  	merrors "gitee.com/sasukebo/go-micro/v4/errors"
    23  	"gitee.com/sasukebo/go-micro/v4/logger"
    24  )
    25  
    26  var (
    27  	lastStreamResponseError = errors.New("EOS")
    28  
    29  	// Precompute the reflect type for error. Can't use error directly
    30  	// because Typeof takes an empty interface value. This is annoying.
    31  	typeOfError = reflect.TypeOf((*error)(nil)).Elem()
    32  )
    33  
    34  type methodType struct {
    35  	sync.Mutex  // protects counters
    36  	method      reflect.Method
    37  	ArgType     reflect.Type
    38  	ReplyType   reflect.Type
    39  	ContextType reflect.Type
    40  	stream      bool
    41  }
    42  
    43  type service struct {
    44  	name   string                 // name of service
    45  	rcvr   reflect.Value          // receiver of methods for the service
    46  	typ    reflect.Type           // type of the receiver
    47  	method map[string]*methodType // registered methods
    48  }
    49  
    50  type request struct {
    51  	msg  *codec.Message
    52  	next *request // for free list in Server
    53  }
    54  
    55  type response struct {
    56  	msg  *codec.Message
    57  	next *response // for free list in Server
    58  }
    59  
    60  // router represents an RPC router.
    61  type router struct {
    62  	name string
    63  
    64  	mu         sync.Mutex // protects the serviceMap
    65  	serviceMap map[string]*service
    66  
    67  	reqLock sync.Mutex // protects freeReq
    68  	freeReq *request
    69  
    70  	respLock sync.Mutex // protects freeResp
    71  	freeResp *response
    72  
    73  	// handler wrappers
    74  	hdlrWrappers []HandlerWrapper
    75  	// subscriber wrappers
    76  	subWrappers []SubscriberWrapper
    77  
    78  	su          sync.RWMutex
    79  	subscribers map[string][]*subscriber
    80  }
    81  
    82  // rpcRouter encapsulates functions that become a server.Router
    83  type rpcRouter struct {
    84  	h func(context.Context, Request, interface{}) error
    85  	m func(context.Context, Message) error
    86  }
    87  
    88  func (r rpcRouter) ProcessMessage(ctx context.Context, msg Message) error {
    89  	return r.m(ctx, msg)
    90  }
    91  
    92  func (r rpcRouter) ServeRequest(ctx context.Context, req Request, rsp Response) error {
    93  	return r.h(ctx, req, rsp)
    94  }
    95  
    96  func newRpcRouter() *router {
    97  	return &router{
    98  		serviceMap:  make(map[string]*service),
    99  		subscribers: make(map[string][]*subscriber),
   100  	}
   101  }
   102  
   103  // Is this an exported - upper case - name?
   104  func isExported(name string) bool {
   105  	rune, _ := utf8.DecodeRuneInString(name)
   106  	return unicode.IsUpper(rune)
   107  }
   108  
   109  // Is this type exported or a builtin?
   110  func isExportedOrBuiltinType(t reflect.Type) bool {
   111  	for t.Kind() == reflect.Ptr {
   112  		t = t.Elem()
   113  	}
   114  	// PkgPath will be non-empty even for an exported type,
   115  	// so we need to check the type name as well.
   116  	return isExported(t.Name()) || t.PkgPath() == ""
   117  }
   118  
   119  // prepareMethod returns a methodType for the provided method or nil
   120  // in case if the method was unsuitable.
   121  func prepareMethod(method reflect.Method) *methodType {
   122  	mtype := method.Type
   123  	mname := method.Name
   124  	var replyType, argType, contextType reflect.Type
   125  	var stream bool
   126  
   127  	// Method must be exported.
   128  	if method.PkgPath != "" {
   129  		return nil
   130  	}
   131  
   132  	switch mtype.NumIn() {
   133  	case 3:
   134  		// assuming streaming
   135  		argType = mtype.In(2)
   136  		contextType = mtype.In(1)
   137  		stream = true
   138  	case 4:
   139  		// method that takes a context
   140  		argType = mtype.In(2)
   141  		replyType = mtype.In(3)
   142  		contextType = mtype.In(1)
   143  	default:
   144  		logger.Errorf("method %v of %v has wrong number of ins: %v", mname, mtype, mtype.NumIn())
   145  		return nil
   146  	}
   147  
   148  	if stream {
   149  		// check stream type
   150  		streamType := reflect.TypeOf((*Stream)(nil)).Elem()
   151  		if !argType.Implements(streamType) {
   152  			logger.Errorf("%v argument does not implement Stream interface: %v", mname, argType)
   153  			return nil
   154  		}
   155  	} else {
   156  		// if not stream check the replyType
   157  
   158  		// First arg need not be a pointer.
   159  		if !isExportedOrBuiltinType(argType) {
   160  			logger.Errorf("%v argument type not exported: %v", mname, argType)
   161  			return nil
   162  		}
   163  
   164  		if replyType.Kind() != reflect.Ptr {
   165  			logger.Errorf("method %v reply type not a pointer: %v", mname, replyType)
   166  			return nil
   167  		}
   168  
   169  		// Reply type must be exported.
   170  		if !isExportedOrBuiltinType(replyType) {
   171  			logger.Errorf("method %v reply type not exported: %v", mname, replyType)
   172  			return nil
   173  		}
   174  	}
   175  
   176  	// Method needs one out.
   177  	if mtype.NumOut() != 1 {
   178  		logger.Errorf("method %v has wrong number of outs: %v", mname, mtype.NumOut())
   179  		return nil
   180  	}
   181  	// The return type of the method must be error.
   182  	if returnType := mtype.Out(0); returnType != typeOfError {
   183  		logger.Errorf("method %v returns %v not error", mname, returnType.String())
   184  		return nil
   185  	}
   186  	return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
   187  }
   188  
   189  func (router *router) sendResponse(sending sync.Locker, req *request, reply interface{}, cc codec.Writer, last bool) error {
   190  	msg := new(codec.Message)
   191  	msg.Type = codec.Response
   192  	resp := router.getResponse()
   193  	resp.msg = msg
   194  
   195  	resp.msg.Id = req.msg.Id
   196  	sending.Lock()
   197  	err := cc.Write(resp.msg, reply)
   198  	sending.Unlock()
   199  	router.freeResponse(resp)
   200  	return err
   201  }
   202  
   203  func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, cc codec.Writer) error {
   204  	defer router.freeRequest(req)
   205  
   206  	function := mtype.method.Func
   207  	var returnValues []reflect.Value
   208  
   209  	r := &rpcRequest{
   210  		service:     req.msg.Target,
   211  		contentType: req.msg.Header["Content-Type"],
   212  		method:      req.msg.Method,
   213  		endpoint:    req.msg.Endpoint,
   214  		body:        req.msg.Body,
   215  		header:      req.msg.Header,
   216  	}
   217  
   218  	// only set if not nil
   219  	if argv.IsValid() {
   220  		r.rawBody = argv.Interface()
   221  	}
   222  
   223  	if !mtype.stream {
   224  		fn := func(ctx context.Context, req Request, rsp interface{}) error {
   225  			returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(argv.Interface()), reflect.ValueOf(rsp)})
   226  
   227  			// The return value for the method is an error.
   228  			if err := returnValues[0].Interface(); err != nil {
   229  				return err.(error)
   230  			}
   231  
   232  			return nil
   233  		}
   234  
   235  		// wrap the handler
   236  		for i := len(router.hdlrWrappers); i > 0; i-- {
   237  			fn = router.hdlrWrappers[i-1](fn)
   238  		}
   239  
   240  		// execute handler
   241  		if err := fn(ctx, r, replyv.Interface()); err != nil {
   242  			return err
   243  		}
   244  
   245  		// send response
   246  		return router.sendResponse(sending, req, replyv.Interface(), cc, true)
   247  	}
   248  
   249  	// declare a local error to see if we errored out already
   250  	// keep track of the type, to make sure we return
   251  	// the same one consistently
   252  	rawStream := &rpcStream{
   253  		context: ctx,
   254  		codec:   cc.(codec.Codec),
   255  		request: r,
   256  		id:      req.msg.Id,
   257  	}
   258  
   259  	// Invoke the method, providing a new value for the reply.
   260  	fn := func(ctx context.Context, req Request, stream interface{}) error {
   261  		returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)})
   262  		if err := returnValues[0].Interface(); err != nil {
   263  			// the function returned an error, we use that
   264  			return err.(error)
   265  		} else if serr := rawStream.Error(); serr == io.EOF || serr == io.ErrUnexpectedEOF {
   266  			return nil
   267  		} else {
   268  			// no error, we send the special EOS error
   269  			return lastStreamResponseError
   270  		}
   271  	}
   272  
   273  	// wrap the handler
   274  	for i := len(router.hdlrWrappers); i > 0; i-- {
   275  		fn = router.hdlrWrappers[i-1](fn)
   276  	}
   277  
   278  	// client.Stream request
   279  	r.stream = true
   280  
   281  	// execute handler
   282  	return fn(ctx, r, rawStream)
   283  }
   284  
   285  func (m *methodType) prepareContext(ctx context.Context) reflect.Value {
   286  	if contextv := reflect.ValueOf(ctx); contextv.IsValid() {
   287  		return contextv
   288  	}
   289  	return reflect.Zero(m.ContextType)
   290  }
   291  
   292  func (router *router) getRequest() *request {
   293  	router.reqLock.Lock()
   294  	req := router.freeReq
   295  	if req == nil {
   296  		req = new(request)
   297  	} else {
   298  		router.freeReq = req.next
   299  		*req = request{}
   300  	}
   301  	router.reqLock.Unlock()
   302  	return req
   303  }
   304  
   305  func (router *router) freeRequest(req *request) {
   306  	router.reqLock.Lock()
   307  	req.next = router.freeReq
   308  	router.freeReq = req
   309  	router.reqLock.Unlock()
   310  }
   311  
   312  func (router *router) getResponse() *response {
   313  	router.respLock.Lock()
   314  	resp := router.freeResp
   315  	if resp == nil {
   316  		resp = new(response)
   317  	} else {
   318  		router.freeResp = resp.next
   319  		*resp = response{}
   320  	}
   321  	router.respLock.Unlock()
   322  	return resp
   323  }
   324  
   325  func (router *router) freeResponse(resp *response) {
   326  	router.respLock.Lock()
   327  	resp.next = router.freeResp
   328  	router.freeResp = resp
   329  	router.respLock.Unlock()
   330  }
   331  
   332  func (router *router) readRequest(r Request) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) {
   333  	cc := r.Codec()
   334  
   335  	service, mtype, req, keepReading, err = router.readHeader(cc)
   336  	if err != nil {
   337  		if !keepReading {
   338  			return
   339  		}
   340  		// discard body
   341  		cc.ReadBody(nil)
   342  		return
   343  	}
   344  	// is it a streaming request? then we don't read the body
   345  	if mtype.stream {
   346  		if cc.(codec.Codec).String() != "grpc" {
   347  			cc.ReadBody(nil)
   348  		}
   349  		return
   350  	}
   351  
   352  	// Decode the argument value.
   353  	argIsValue := false // if true, need to indirect before calling.
   354  	if mtype.ArgType.Kind() == reflect.Ptr {
   355  		argv = reflect.New(mtype.ArgType.Elem())
   356  	} else {
   357  		argv = reflect.New(mtype.ArgType)
   358  		argIsValue = true
   359  	}
   360  	// argv guaranteed to be a pointer now.
   361  	if err = cc.ReadBody(argv.Interface()); err != nil {
   362  		return
   363  	}
   364  	if argIsValue {
   365  		argv = argv.Elem()
   366  	}
   367  
   368  	if !mtype.stream {
   369  		replyv = reflect.New(mtype.ReplyType.Elem())
   370  	}
   371  	return
   372  }
   373  
   374  func (router *router) readHeader(cc codec.Reader) (service *service, mtype *methodType, req *request, keepReading bool, err error) {
   375  	// Grab the request header.
   376  	msg := new(codec.Message)
   377  	msg.Type = codec.Request
   378  	req = router.getRequest()
   379  	req.msg = msg
   380  
   381  	err = cc.ReadHeader(msg, msg.Type)
   382  	if err != nil {
   383  		req = nil
   384  		if err == io.EOF || err == io.ErrUnexpectedEOF {
   385  			return
   386  		}
   387  		err = errors.New("rpc: router cannot decode request: " + err.Error())
   388  		return
   389  	}
   390  
   391  	// We read the header successfully. If we see an error now,
   392  	// we can still recover and move on to the next request.
   393  	keepReading = true
   394  
   395  	serviceMethod := strings.Split(req.msg.Endpoint, ".")
   396  	if len(serviceMethod) != 2 {
   397  		err = errors.New("rpc: service/endpoint request ill-formed: " + req.msg.Endpoint)
   398  		return
   399  	}
   400  	// Look up the request.
   401  	router.mu.Lock()
   402  	service = router.serviceMap[serviceMethod[0]]
   403  	router.mu.Unlock()
   404  	if service == nil {
   405  		err = errors.New("rpc: can't find service " + serviceMethod[0])
   406  		return
   407  	}
   408  	mtype = service.method[serviceMethod[1]]
   409  	if mtype == nil {
   410  		err = errors.New("rpc: can't find method " + serviceMethod[1])
   411  	}
   412  	return
   413  }
   414  
   415  func (router *router) NewHandler(h interface{}, opts ...HandlerOption) Handler {
   416  	return newRpcHandler(h, opts...)
   417  }
   418  
   419  func (router *router) Handle(h Handler) error {
   420  	router.mu.Lock()
   421  	defer router.mu.Unlock()
   422  	if router.serviceMap == nil {
   423  		router.serviceMap = make(map[string]*service)
   424  	}
   425  
   426  	if len(h.Name()) == 0 {
   427  		return errors.New("rpc.Handle: handler has no name")
   428  	}
   429  	if !isExported(h.Name()) {
   430  		return errors.New("rpc.Handle: type " + h.Name() + " is not exported")
   431  	}
   432  
   433  	rcvr := h.Handler()
   434  	s := new(service)
   435  	s.typ = reflect.TypeOf(rcvr)
   436  	s.rcvr = reflect.ValueOf(rcvr)
   437  
   438  	// check name
   439  	if _, present := router.serviceMap[h.Name()]; present {
   440  		return errors.New("rpc.Handle: service already defined: " + h.Name())
   441  	}
   442  
   443  	s.name = h.Name()
   444  	s.method = make(map[string]*methodType)
   445  
   446  	// Install the methods
   447  	for m := 0; m < s.typ.NumMethod(); m++ {
   448  		method := s.typ.Method(m)
   449  		if mt := prepareMethod(method); mt != nil {
   450  			s.method[method.Name] = mt
   451  		}
   452  	}
   453  
   454  	// Check there are methods
   455  	if len(s.method) == 0 {
   456  		return errors.New("rpc Register: type " + s.name + " has no exported methods of suitable type")
   457  	}
   458  
   459  	// save handler
   460  	router.serviceMap[s.name] = s
   461  	return nil
   462  }
   463  
   464  func (router *router) ServeRequest(ctx context.Context, r Request, rsp Response) error {
   465  	sending := new(sync.Mutex)
   466  	service, mtype, req, argv, replyv, keepReading, err := router.readRequest(r)
   467  	if err != nil {
   468  		if !keepReading {
   469  			return err
   470  		}
   471  		// send a response if we actually managed to read a header.
   472  		if req != nil {
   473  			router.freeRequest(req)
   474  		}
   475  		return err
   476  	}
   477  	return service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec())
   478  }
   479  
   480  func (router *router) NewSubscriber(topic string, handler interface{}, opts ...SubscriberOption) Subscriber {
   481  	return newSubscriber(topic, handler, opts...)
   482  }
   483  
   484  func (router *router) Subscribe(s Subscriber) error {
   485  	sub, ok := s.(*subscriber)
   486  	if !ok {
   487  		return fmt.Errorf("invalid subscriber: expected *subscriber")
   488  	}
   489  	if len(sub.handlers) == 0 {
   490  		return fmt.Errorf("invalid subscriber: no handler functions")
   491  	}
   492  
   493  	if err := validateSubscriber(sub); err != nil {
   494  		return err
   495  	}
   496  
   497  	router.su.Lock()
   498  	defer router.su.Unlock()
   499  
   500  	// append to subscribers
   501  	subs := router.subscribers[sub.Topic()]
   502  	subs = append(subs, sub)
   503  	router.subscribers[sub.Topic()] = subs
   504  
   505  	return nil
   506  }
   507  
   508  func (router *router) ProcessMessage(ctx context.Context, msg Message) (err error) {
   509  	defer func() {
   510  		// recover any panics
   511  		if r := recover(); r != nil {
   512  			logger.Errorf("panic recovered: %v", r)
   513  			logger.Error(string(debug.Stack()))
   514  			err = merrors.InternalServerError("go.micro.server", "panic recovered: %v", r)
   515  		}
   516  	}()
   517  
   518  	router.su.RLock()
   519  	// get the subscribers by topic
   520  	subs, ok := router.subscribers[msg.Topic()]
   521  	// unlock since we only need to get the subs
   522  	router.su.RUnlock()
   523  	if !ok {
   524  		return nil
   525  	}
   526  
   527  	var errResults []string
   528  
   529  	// we may have multiple subscribers for the topic
   530  	for _, sub := range subs {
   531  		// we may have multiple handlers per subscriber
   532  		for i := 0; i < len(sub.handlers); i++ {
   533  			// get the handler
   534  			handler := sub.handlers[i]
   535  
   536  			var isVal bool
   537  			var req reflect.Value
   538  
   539  			// check whether the handler is a pointer
   540  			if handler.reqType.Kind() == reflect.Ptr {
   541  				req = reflect.New(handler.reqType.Elem())
   542  			} else {
   543  				req = reflect.New(handler.reqType)
   544  				isVal = true
   545  			}
   546  
   547  			// if its a value get the element
   548  			if isVal {
   549  				req = req.Elem()
   550  			}
   551  
   552  			cc := msg.Codec()
   553  
   554  			// read the header. mostly a noop
   555  			if err = cc.ReadHeader(&codec.Message{}, codec.Event); err != nil {
   556  				return err
   557  			}
   558  
   559  			// read the body into the handler request value
   560  			if err = cc.ReadBody(req.Addr().Interface()); err != nil {
   561  				return err
   562  			}
   563  
   564  			// create the handler which will honour the SubscriberFunc type
   565  			fn := func(ctx context.Context, msg Message) error {
   566  				var vals []reflect.Value
   567  				if sub.typ.Kind() != reflect.Func {
   568  					vals = append(vals, sub.rcvr)
   569  				}
   570  				if handler.ctxType != nil {
   571  					vals = append(vals, reflect.ValueOf(ctx))
   572  				}
   573  
   574  				// values to pass the handler
   575  				vals = append(vals, reflect.ValueOf(msg.Payload()))
   576  
   577  				// execute the actuall call of the handler
   578  				returnValues := handler.method.Call(vals)
   579  				if rerr := returnValues[0].Interface(); rerr != nil {
   580  					err = rerr.(error)
   581  				}
   582  				return err
   583  			}
   584  
   585  			// wrap with subscriber wrappers
   586  			for i := len(router.subWrappers); i > 0; i-- {
   587  				fn = router.subWrappers[i-1](fn)
   588  			}
   589  
   590  			// create new rpc message
   591  			rpcMsg := &rpcMessage{
   592  				topic:       msg.Topic(),
   593  				contentType: msg.ContentType(),
   594  				payload:     req.Interface(),
   595  				codec:       msg.(*rpcMessage).codec,
   596  				header:      msg.Header(),
   597  				body:        msg.Body(),
   598  			}
   599  
   600  			// execute the message handler
   601  			if err = fn(ctx, rpcMsg); err != nil {
   602  				errResults = append(errResults, err.Error())
   603  			}
   604  		}
   605  	}
   606  
   607  	// if no errors just return
   608  	if len(errResults) > 0 {
   609  		err = merrors.InternalServerError("go.micro.server", "subscriber error: %v", strings.Join(errResults, "\n"))
   610  	}
   611  
   612  	return err
   613  }