github.com/micro/go-micro/v2@v2.9.1/server/rpc_router.go (about)

     1  package server
     2  
     3  // Copyright 2009 The Go Authors. All rights reserved.
     4  // Use of this source code is governed by a BSD-style
     5  // license that can be found in the LICENSE file.
     6  //
     7  // Meh, we need to get rid of this shit
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"reflect"
    15  	"runtime/debug"
    16  	"strings"
    17  	"sync"
    18  	"unicode"
    19  	"unicode/utf8"
    20  
    21  	"github.com/micro/go-micro/v2/codec"
    22  	merrors "github.com/micro/go-micro/v2/errors"
    23  )
    24  
    25  var (
    26  	lastStreamResponseError = errors.New("EOS")
    27  
    28  	// Precompute the reflect type for error. Can't use error directly
    29  	// because Typeof takes an empty interface value. This is annoying.
    30  	typeOfError = reflect.TypeOf((*error)(nil)).Elem()
    31  )
    32  
// methodType describes a single handler method accepted by prepareMethod,
// together with the reflected types of its parameters.
type methodType struct {
	sync.Mutex  // NOTE(review): originally "protects counters", but no counters are visible in this struct — confirm whether the lock is still needed
	method      reflect.Method // the reflected method; its Func is invoked by service.call
	ArgType     reflect.Type   // type of the method's third input (request argument, or the stream for streaming methods)
	ReplyType   reflect.Type   // type of the fourth input; left nil for streaming (three-input) methods
	ContextType reflect.Type   // type of the method's second input
	stream      bool           // true when the method has three inputs and is treated as streaming
}
    41  
// service is one registered handler: the receiver value plus the subset of
// its methods that prepareMethod accepted, keyed by method name.
type service struct {
	name   string                 // name of service
	rcvr   reflect.Value          // receiver of methods for the service
	typ    reflect.Type           // type of the receiver
	method map[string]*methodType // registered methods
}
    48  
// request wraps the codec message of an inbound call. Instances are recycled
// through the router's freeReq list (see getRequest / freeRequest).
type request struct {
	msg  *codec.Message
	next *request // next entry in the router's free list
}
    53  
// response wraps the codec message of an outbound reply. Instances are
// recycled through the router's freeResp list (see getResponse / freeResponse).
type response struct {
	msg  *codec.Message
	next *response // next entry in the router's free list
}
    58  
// router represents an RPC router. It holds the registered handler services
// and topic subscribers, and keeps free lists of request/response envelopes
// for reuse.
type router struct {
	name string // NOTE(review): not read or written anywhere in this file — confirm usage

	mu         sync.Mutex // protects the serviceMap
	serviceMap map[string]*service // registered services keyed by handler name

	reqLock sync.Mutex // protects freeReq
	freeReq *request // head of the request free list

	respLock sync.Mutex // protects freeResp
	freeResp *response // head of the response free list

	// handler wrappers; applied last-to-first so the first entry ends up outermost
	hdlrWrappers []HandlerWrapper
	// subscriber wrappers; applied last-to-first so the first entry ends up outermost
	subWrappers []SubscriberWrapper

	su          sync.RWMutex // guards subscribers
	subscribers map[string][]*subscriber // subscribers keyed by topic
}
    80  
// rpcRouter encapsulates functions that become a server.Router.
// ServeRequest delegates to h and ProcessMessage delegates to m.
type rpcRouter struct {
	h func(context.Context, Request, interface{}) error // request handler; receives the Response as its third argument
	m func(context.Context, Message) error              // message (pub/sub) handler
}
    86  
// ProcessMessage passes ctx and msg to the wrapped m function and returns its
// result unchanged.
func (r rpcRouter) ProcessMessage(ctx context.Context, msg Message) error {
	return r.m(ctx, msg)
}
    90  
// ServeRequest passes ctx, req and rsp to the wrapped h function and returns
// its result unchanged.
func (r rpcRouter) ServeRequest(ctx context.Context, req Request, rsp Response) error {
	return r.h(ctx, req, rsp)
}
    94  
    95  func newRpcRouter() *router {
    96  	return &router{
    97  		serviceMap:  make(map[string]*service),
    98  		subscribers: make(map[string][]*subscriber),
    99  	}
   100  }
   101  
   102  // Is this an exported - upper case - name?
   103  func isExported(name string) bool {
   104  	rune, _ := utf8.DecodeRuneInString(name)
   105  	return unicode.IsUpper(rune)
   106  }
   107  
   108  // Is this type exported or a builtin?
   109  func isExportedOrBuiltinType(t reflect.Type) bool {
   110  	for t.Kind() == reflect.Ptr {
   111  		t = t.Elem()
   112  	}
   113  	// PkgPath will be non-empty even for an exported type,
   114  	// so we need to check the type name as well.
   115  	return isExported(t.Name()) || t.PkgPath() == ""
   116  }
   117  
   118  // prepareMethod returns a methodType for the provided method or nil
   119  // in case if the method was unsuitable.
   120  func prepareMethod(method reflect.Method) *methodType {
   121  	mtype := method.Type
   122  	mname := method.Name
   123  	var replyType, argType, contextType reflect.Type
   124  	var stream bool
   125  
   126  	// Method must be exported.
   127  	if method.PkgPath != "" {
   128  		return nil
   129  	}
   130  
   131  	switch mtype.NumIn() {
   132  	case 3:
   133  		// assuming streaming
   134  		argType = mtype.In(2)
   135  		contextType = mtype.In(1)
   136  		stream = true
   137  	case 4:
   138  		// method that takes a context
   139  		argType = mtype.In(2)
   140  		replyType = mtype.In(3)
   141  		contextType = mtype.In(1)
   142  	default:
   143  		log.Errorf("method %v of %v has wrong number of ins: %v", mname, mtype, mtype.NumIn())
   144  		return nil
   145  	}
   146  
   147  	if stream {
   148  		// check stream type
   149  		streamType := reflect.TypeOf((*Stream)(nil)).Elem()
   150  		if !argType.Implements(streamType) {
   151  			log.Errorf("%v argument does not implement Stream interface: %v", mname, argType)
   152  			return nil
   153  		}
   154  	} else {
   155  		// if not stream check the replyType
   156  
   157  		// First arg need not be a pointer.
   158  		if !isExportedOrBuiltinType(argType) {
   159  			log.Errorf("%v argument type not exported: %v", mname, argType)
   160  			return nil
   161  		}
   162  
   163  		if replyType.Kind() != reflect.Ptr {
   164  			log.Errorf("method %v reply type not a pointer: %v", mname, replyType)
   165  			return nil
   166  		}
   167  
   168  		// Reply type must be exported.
   169  		if !isExportedOrBuiltinType(replyType) {
   170  			log.Errorf("method %v reply type not exported: %v", mname, replyType)
   171  			return nil
   172  		}
   173  	}
   174  
   175  	// Method needs one out.
   176  	if mtype.NumOut() != 1 {
   177  		log.Errorf("method %v has wrong number of outs: %v", mname, mtype.NumOut())
   178  		return nil
   179  	}
   180  	// The return type of the method must be error.
   181  	if returnType := mtype.Out(0); returnType != typeOfError {
   182  		log.Errorf("method %v returns %v not error", mname, returnType.String())
   183  		return nil
   184  	}
   185  	return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
   186  }
   187  
   188  func (router *router) sendResponse(sending sync.Locker, req *request, reply interface{}, cc codec.Writer, last bool) error {
   189  	msg := new(codec.Message)
   190  	msg.Type = codec.Response
   191  	resp := router.getResponse()
   192  	resp.msg = msg
   193  
   194  	resp.msg.Id = req.msg.Id
   195  	sending.Lock()
   196  	err := cc.Write(resp.msg, reply)
   197  	sending.Unlock()
   198  	router.freeResponse(resp)
   199  	return err
   200  }
   201  
   202  func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, cc codec.Writer) error {
   203  	defer router.freeRequest(req)
   204  
   205  	function := mtype.method.Func
   206  	var returnValues []reflect.Value
   207  
   208  	r := &rpcRequest{
   209  		service:     req.msg.Target,
   210  		contentType: req.msg.Header["Content-Type"],
   211  		method:      req.msg.Method,
   212  		endpoint:    req.msg.Endpoint,
   213  		body:        req.msg.Body,
   214  		header:      req.msg.Header,
   215  	}
   216  
   217  	// only set if not nil
   218  	if argv.IsValid() {
   219  		r.rawBody = argv.Interface()
   220  	}
   221  
   222  	if !mtype.stream {
   223  		fn := func(ctx context.Context, req Request, rsp interface{}) error {
   224  			returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(argv.Interface()), reflect.ValueOf(rsp)})
   225  
   226  			// The return value for the method is an error.
   227  			if err := returnValues[0].Interface(); err != nil {
   228  				return err.(error)
   229  			}
   230  
   231  			return nil
   232  		}
   233  
   234  		// wrap the handler
   235  		for i := len(router.hdlrWrappers); i > 0; i-- {
   236  			fn = router.hdlrWrappers[i-1](fn)
   237  		}
   238  
   239  		// execute handler
   240  		if err := fn(ctx, r, replyv.Interface()); err != nil {
   241  			return err
   242  		}
   243  
   244  		// send response
   245  		return router.sendResponse(sending, req, replyv.Interface(), cc, true)
   246  	}
   247  
   248  	// declare a local error to see if we errored out already
   249  	// keep track of the type, to make sure we return
   250  	// the same one consistently
   251  	rawStream := &rpcStream{
   252  		context: ctx,
   253  		codec:   cc.(codec.Codec),
   254  		request: r,
   255  		id:      req.msg.Id,
   256  	}
   257  
   258  	// Invoke the method, providing a new value for the reply.
   259  	fn := func(ctx context.Context, req Request, stream interface{}) error {
   260  		returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)})
   261  		if err := returnValues[0].Interface(); err != nil {
   262  			// the function returned an error, we use that
   263  			return err.(error)
   264  		} else if serr := rawStream.Error(); serr == io.EOF || serr == io.ErrUnexpectedEOF {
   265  			return nil
   266  		} else {
   267  			// no error, we send the special EOS error
   268  			return lastStreamResponseError
   269  		}
   270  	}
   271  
   272  	// wrap the handler
   273  	for i := len(router.hdlrWrappers); i > 0; i-- {
   274  		fn = router.hdlrWrappers[i-1](fn)
   275  	}
   276  
   277  	// client.Stream request
   278  	r.stream = true
   279  
   280  	// execute handler
   281  	return fn(ctx, r, rawStream)
   282  }
   283  
   284  func (m *methodType) prepareContext(ctx context.Context) reflect.Value {
   285  	if contextv := reflect.ValueOf(ctx); contextv.IsValid() {
   286  		return contextv
   287  	}
   288  	return reflect.Zero(m.ContextType)
   289  }
   290  
   291  func (router *router) getRequest() *request {
   292  	router.reqLock.Lock()
   293  	req := router.freeReq
   294  	if req == nil {
   295  		req = new(request)
   296  	} else {
   297  		router.freeReq = req.next
   298  		*req = request{}
   299  	}
   300  	router.reqLock.Unlock()
   301  	return req
   302  }
   303  
   304  func (router *router) freeRequest(req *request) {
   305  	router.reqLock.Lock()
   306  	req.next = router.freeReq
   307  	router.freeReq = req
   308  	router.reqLock.Unlock()
   309  }
   310  
   311  func (router *router) getResponse() *response {
   312  	router.respLock.Lock()
   313  	resp := router.freeResp
   314  	if resp == nil {
   315  		resp = new(response)
   316  	} else {
   317  		router.freeResp = resp.next
   318  		*resp = response{}
   319  	}
   320  	router.respLock.Unlock()
   321  	return resp
   322  }
   323  
   324  func (router *router) freeResponse(resp *response) {
   325  	router.respLock.Lock()
   326  	resp.next = router.freeResp
   327  	router.freeResp = resp
   328  	router.respLock.Unlock()
   329  }
   330  
   331  func (router *router) readRequest(r Request) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) {
   332  	cc := r.Codec()
   333  
   334  	service, mtype, req, keepReading, err = router.readHeader(cc)
   335  	if err != nil {
   336  		if !keepReading {
   337  			return
   338  		}
   339  		// discard body
   340  		cc.ReadBody(nil)
   341  		return
   342  	}
   343  	// is it a streaming request? then we don't read the body
   344  	if mtype.stream {
   345  		if cc.(codec.Codec).String() != "grpc" {
   346  			cc.ReadBody(nil)
   347  		}
   348  		return
   349  	}
   350  
   351  	// Decode the argument value.
   352  	argIsValue := false // if true, need to indirect before calling.
   353  	if mtype.ArgType.Kind() == reflect.Ptr {
   354  		argv = reflect.New(mtype.ArgType.Elem())
   355  	} else {
   356  		argv = reflect.New(mtype.ArgType)
   357  		argIsValue = true
   358  	}
   359  	// argv guaranteed to be a pointer now.
   360  	if err = cc.ReadBody(argv.Interface()); err != nil {
   361  		return
   362  	}
   363  	if argIsValue {
   364  		argv = argv.Elem()
   365  	}
   366  
   367  	if !mtype.stream {
   368  		replyv = reflect.New(mtype.ReplyType.Elem())
   369  	}
   370  	return
   371  }
   372  
   373  func (router *router) readHeader(cc codec.Reader) (service *service, mtype *methodType, req *request, keepReading bool, err error) {
   374  	// Grab the request header.
   375  	msg := new(codec.Message)
   376  	msg.Type = codec.Request
   377  	req = router.getRequest()
   378  	req.msg = msg
   379  
   380  	err = cc.ReadHeader(msg, msg.Type)
   381  	if err != nil {
   382  		req = nil
   383  		if err == io.EOF || err == io.ErrUnexpectedEOF {
   384  			return
   385  		}
   386  		err = errors.New("rpc: router cannot decode request: " + err.Error())
   387  		return
   388  	}
   389  
   390  	// We read the header successfully. If we see an error now,
   391  	// we can still recover and move on to the next request.
   392  	keepReading = true
   393  
   394  	serviceMethod := strings.Split(req.msg.Endpoint, ".")
   395  	if len(serviceMethod) != 2 {
   396  		err = errors.New("rpc: service/endpoint request ill-formed: " + req.msg.Endpoint)
   397  		return
   398  	}
   399  	// Look up the request.
   400  	router.mu.Lock()
   401  	service = router.serviceMap[serviceMethod[0]]
   402  	router.mu.Unlock()
   403  	if service == nil {
   404  		err = errors.New("rpc: can't find service " + serviceMethod[0])
   405  		return
   406  	}
   407  	mtype = service.method[serviceMethod[1]]
   408  	if mtype == nil {
   409  		err = errors.New("rpc: can't find method " + serviceMethod[1])
   410  	}
   411  	return
   412  }
   413  
// NewHandler builds a Handler for h by delegating to newRpcHandler with the
// given options. It does not touch the router; use Handle to register the
// result.
func (router *router) NewHandler(h interface{}, opts ...HandlerOption) Handler {
	return newRpcHandler(h, opts...)
}
   417  
   418  func (router *router) Handle(h Handler) error {
   419  	router.mu.Lock()
   420  	defer router.mu.Unlock()
   421  	if router.serviceMap == nil {
   422  		router.serviceMap = make(map[string]*service)
   423  	}
   424  
   425  	if len(h.Name()) == 0 {
   426  		return errors.New("rpc.Handle: handler has no name")
   427  	}
   428  	if !isExported(h.Name()) {
   429  		return errors.New("rpc.Handle: type " + h.Name() + " is not exported")
   430  	}
   431  
   432  	rcvr := h.Handler()
   433  	s := new(service)
   434  	s.typ = reflect.TypeOf(rcvr)
   435  	s.rcvr = reflect.ValueOf(rcvr)
   436  
   437  	// check name
   438  	if _, present := router.serviceMap[h.Name()]; present {
   439  		return errors.New("rpc.Handle: service already defined: " + h.Name())
   440  	}
   441  
   442  	s.name = h.Name()
   443  	s.method = make(map[string]*methodType)
   444  
   445  	// Install the methods
   446  	for m := 0; m < s.typ.NumMethod(); m++ {
   447  		method := s.typ.Method(m)
   448  		if mt := prepareMethod(method); mt != nil {
   449  			s.method[method.Name] = mt
   450  		}
   451  	}
   452  
   453  	// Check there are methods
   454  	if len(s.method) == 0 {
   455  		return errors.New("rpc Register: type " + s.name + " has no exported methods of suitable type")
   456  	}
   457  
   458  	// save handler
   459  	router.serviceMap[s.name] = s
   460  	return nil
   461  }
   462  
   463  func (router *router) ServeRequest(ctx context.Context, r Request, rsp Response) error {
   464  	sending := new(sync.Mutex)
   465  	service, mtype, req, argv, replyv, keepReading, err := router.readRequest(r)
   466  	if err != nil {
   467  		if !keepReading {
   468  			return err
   469  		}
   470  		// send a response if we actually managed to read a header.
   471  		if req != nil {
   472  			router.freeRequest(req)
   473  		}
   474  		return err
   475  	}
   476  	return service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec())
   477  }
   478  
// NewSubscriber builds a Subscriber for topic and handler by delegating to
// newSubscriber with the given options. It does not touch the router; use
// Subscribe to register the result.
func (router *router) NewSubscriber(topic string, handler interface{}, opts ...SubscriberOption) Subscriber {
	return newSubscriber(topic, handler, opts...)
}
   482  
   483  func (router *router) Subscribe(s Subscriber) error {
   484  	sub, ok := s.(*subscriber)
   485  	if !ok {
   486  		return fmt.Errorf("invalid subscriber: expected *subscriber")
   487  	}
   488  	if len(sub.handlers) == 0 {
   489  		return fmt.Errorf("invalid subscriber: no handler functions")
   490  	}
   491  
   492  	if err := validateSubscriber(sub); err != nil {
   493  		return err
   494  	}
   495  
   496  	router.su.Lock()
   497  	defer router.su.Unlock()
   498  
   499  	// append to subscribers
   500  	subs := router.subscribers[sub.Topic()]
   501  	subs = append(subs, sub)
   502  	router.subscribers[sub.Topic()] = subs
   503  
   504  	return nil
   505  }
   506  
   507  func (router *router) ProcessMessage(ctx context.Context, msg Message) (err error) {
   508  	defer func() {
   509  		// recover any panics
   510  		if r := recover(); r != nil {
   511  			log.Errorf("panic recovered: %v", r)
   512  			log.Error(string(debug.Stack()))
   513  			err = merrors.InternalServerError("go.micro.server", "panic recovered: %v", r)
   514  		}
   515  	}()
   516  
   517  	router.su.RLock()
   518  	// get the subscribers by topic
   519  	subs, ok := router.subscribers[msg.Topic()]
   520  	// unlock since we only need to get the subs
   521  	router.su.RUnlock()
   522  	if !ok {
   523  		return nil
   524  	}
   525  
   526  	var errResults []string
   527  
   528  	// we may have multiple subscribers for the topic
   529  	for _, sub := range subs {
   530  		// we may have multiple handlers per subscriber
   531  		for i := 0; i < len(sub.handlers); i++ {
   532  			// get the handler
   533  			handler := sub.handlers[i]
   534  
   535  			var isVal bool
   536  			var req reflect.Value
   537  
   538  			// check whether the handler is a pointer
   539  			if handler.reqType.Kind() == reflect.Ptr {
   540  				req = reflect.New(handler.reqType.Elem())
   541  			} else {
   542  				req = reflect.New(handler.reqType)
   543  				isVal = true
   544  			}
   545  
   546  			// if its a value get the element
   547  			if isVal {
   548  				req = req.Elem()
   549  			}
   550  
   551  			cc := msg.Codec()
   552  
   553  			// read the header. mostly a noop
   554  			if err = cc.ReadHeader(&codec.Message{}, codec.Event); err != nil {
   555  				return err
   556  			}
   557  
   558  			// read the body into the handler request value
   559  			if err = cc.ReadBody(req.Interface()); err != nil {
   560  				return err
   561  			}
   562  
   563  			// create the handler which will honour the SubscriberFunc type
   564  			fn := func(ctx context.Context, msg Message) error {
   565  				var vals []reflect.Value
   566  				if sub.typ.Kind() != reflect.Func {
   567  					vals = append(vals, sub.rcvr)
   568  				}
   569  				if handler.ctxType != nil {
   570  					vals = append(vals, reflect.ValueOf(ctx))
   571  				}
   572  
   573  				// values to pass the handler
   574  				vals = append(vals, reflect.ValueOf(msg.Payload()))
   575  
   576  				// execute the actuall call of the handler
   577  				returnValues := handler.method.Call(vals)
   578  				if rerr := returnValues[0].Interface(); rerr != nil {
   579  					err = rerr.(error)
   580  				}
   581  				return err
   582  			}
   583  
   584  			// wrap with subscriber wrappers
   585  			for i := len(router.subWrappers); i > 0; i-- {
   586  				fn = router.subWrappers[i-1](fn)
   587  			}
   588  
   589  			// create new rpc message
   590  			rpcMsg := &rpcMessage{
   591  				topic:       msg.Topic(),
   592  				contentType: msg.ContentType(),
   593  				payload:     req.Interface(),
   594  				codec:       msg.(*rpcMessage).codec,
   595  				header:      msg.Header(),
   596  				body:        msg.Body(),
   597  			}
   598  
   599  			// execute the message handler
   600  			if err = fn(ctx, rpcMsg); err != nil {
   601  				errResults = append(errResults, err.Error())
   602  			}
   603  		}
   604  	}
   605  
   606  	// if no errors just return
   607  	if len(errResults) > 0 {
   608  		err = merrors.InternalServerError("go.micro.server", "subscriber error: %v", strings.Join(errResults, "\n"))
   609  	}
   610  
   611  	return err
   612  }