gitee.com/liuxuezhan/go-micro-v1.18.0@v1.0.0/server/rpc_router.go (about)

     1  package server
     2  
     3  // Copyright 2009 The Go Authors. All rights reserved.
     4  // Use of this source code is governed by a BSD-style
     5  // license that can be found in the LICENSE file.
     6  //
     7  // NOTE: this router is legacy code and is a candidate for removal.
     8  
     9  import (
    10  	"context"
    11  	"errors"
    12  	"fmt"
    13  	"io"
    14  	"reflect"
    15  	"runtime/debug"
    16  	"strings"
    17  	"sync"
    18  	"unicode"
    19  	"unicode/utf8"
    20  
    21  	"gitee.com/liuxuezhan/go-micro-v1.18.0/codec"
    22  	merrors "gitee.com/liuxuezhan/go-micro-v1.18.0/errors"
    23  	"gitee.com/liuxuezhan/go-micro-v1.18.0/util/log"
    24  )
    25  
    26  var (
    27  	lastStreamResponseError = errors.New("EOS")
    28  	// A value sent as a placeholder for the server's response value when the server
    29  	// receives an invalid request. It is never decoded by the client since the Response
    30  	// contains an error when it is used.
    31  	invalidRequest = struct{}{}
    32  
    33  	// Precompute the reflect type for error. Can't use error directly
    34  	// because Typeof takes an empty interface value. This is annoying.
    35  	typeOfError = reflect.TypeOf((*error)(nil)).Elem()
    36  )
    37  
    38  type methodType struct {
    39  	sync.Mutex  // protects counters
    40  	method      reflect.Method
    41  	ArgType     reflect.Type
    42  	ReplyType   reflect.Type
    43  	ContextType reflect.Type
    44  	stream      bool
    45  }
    46  
    47  type service struct {
    48  	name   string                 // name of service
    49  	rcvr   reflect.Value          // receiver of methods for the service
    50  	typ    reflect.Type           // type of the receiver
    51  	method map[string]*methodType // registered methods
    52  }
    53  
    54  type request struct {
    55  	msg  *codec.Message
    56  	next *request // for free list in Server
    57  }
    58  
    59  type response struct {
    60  	msg  *codec.Message
    61  	next *response // for free list in Server
    62  }
    63  
    64  // router represents an RPC router.
    65  type router struct {
    66  	name string
    67  
    68  	mu         sync.Mutex // protects the serviceMap
    69  	serviceMap map[string]*service
    70  
    71  	reqLock sync.Mutex // protects freeReq
    72  	freeReq *request
    73  
    74  	respLock sync.Mutex // protects freeResp
    75  	freeResp *response
    76  
    77  	// handler wrappers
    78  	hdlrWrappers []HandlerWrapper
    79  	// subscriber wrappers
    80  	subWrappers []SubscriberWrapper
    81  
    82  	su          sync.RWMutex
    83  	subscribers map[string][]*subscriber
    84  }
    85  
    86  // rpcRouter encapsulates functions that become a server.Router
    87  type rpcRouter struct {
    88  	h func(context.Context, Request, interface{}) error
    89  	m func(context.Context, Message) error
    90  }
    91  
    92  func (r rpcRouter) ProcessMessage(ctx context.Context, msg Message) error {
    93  	return r.m(ctx, msg)
    94  }
    95  
    96  func (r rpcRouter) ServeRequest(ctx context.Context, req Request, rsp Response) error {
    97  	return r.h(ctx, req, rsp)
    98  }
    99  
   100  func newRpcRouter() *router {
   101  	return &router{
   102  		serviceMap:  make(map[string]*service),
   103  		subscribers: make(map[string][]*subscriber),
   104  	}
   105  }
   106  
   107  // Is this an exported - upper case - name?
   108  func isExported(name string) bool {
   109  	rune, _ := utf8.DecodeRuneInString(name)
   110  	return unicode.IsUpper(rune)
   111  }
   112  
   113  // Is this type exported or a builtin?
   114  func isExportedOrBuiltinType(t reflect.Type) bool {
   115  	for t.Kind() == reflect.Ptr {
   116  		t = t.Elem()
   117  	}
   118  	// PkgPath will be non-empty even for an exported type,
   119  	// so we need to check the type name as well.
   120  	return isExported(t.Name()) || t.PkgPath() == ""
   121  }
   122  
   123  // prepareMethod returns a methodType for the provided method or nil
   124  // in case if the method was unsuitable.
   125  func prepareMethod(method reflect.Method) *methodType {
   126  	mtype := method.Type
   127  	mname := method.Name
   128  	var replyType, argType, contextType reflect.Type
   129  	var stream bool
   130  
   131  	// Method must be exported.
   132  	if method.PkgPath != "" {
   133  		return nil
   134  	}
   135  
   136  	switch mtype.NumIn() {
   137  	case 3:
   138  		// assuming streaming
   139  		argType = mtype.In(2)
   140  		contextType = mtype.In(1)
   141  		stream = true
   142  	case 4:
   143  		// method that takes a context
   144  		argType = mtype.In(2)
   145  		replyType = mtype.In(3)
   146  		contextType = mtype.In(1)
   147  	default:
   148  		log.Log("method", mname, "of", mtype, "has wrong number of ins:", mtype.NumIn())
   149  		return nil
   150  	}
   151  
   152  	if stream {
   153  		// check stream type
   154  		streamType := reflect.TypeOf((*Stream)(nil)).Elem()
   155  		if !argType.Implements(streamType) {
   156  			log.Log(mname, "argument does not implement Stream interface:", argType)
   157  			return nil
   158  		}
   159  	} else {
   160  		// if not stream check the replyType
   161  
   162  		// First arg need not be a pointer.
   163  		if !isExportedOrBuiltinType(argType) {
   164  			log.Log(mname, "argument type not exported:", argType)
   165  			return nil
   166  		}
   167  
   168  		if replyType.Kind() != reflect.Ptr {
   169  			log.Log("method", mname, "reply type not a pointer:", replyType)
   170  			return nil
   171  		}
   172  
   173  		// Reply type must be exported.
   174  		if !isExportedOrBuiltinType(replyType) {
   175  			log.Log("method", mname, "reply type not exported:", replyType)
   176  			return nil
   177  		}
   178  	}
   179  
   180  	// Method needs one out.
   181  	if mtype.NumOut() != 1 {
   182  		log.Log("method", mname, "has wrong number of outs:", mtype.NumOut())
   183  		return nil
   184  	}
   185  	// The return type of the method must be error.
   186  	if returnType := mtype.Out(0); returnType != typeOfError {
   187  		log.Log("method", mname, "returns", returnType.String(), "not error")
   188  		return nil
   189  	}
   190  	return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
   191  }
   192  
   193  func (router *router) sendResponse(sending sync.Locker, req *request, reply interface{}, cc codec.Writer, last bool) error {
   194  	msg := new(codec.Message)
   195  	msg.Type = codec.Response
   196  	resp := router.getResponse()
   197  	resp.msg = msg
   198  
   199  	resp.msg.Id = req.msg.Id
   200  	sending.Lock()
   201  	err := cc.Write(resp.msg, reply)
   202  	sending.Unlock()
   203  	router.freeResponse(resp)
   204  	return err
   205  }
   206  
   207  func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, cc codec.Writer) error {
   208  	defer router.freeRequest(req)
   209  
   210  	function := mtype.method.Func
   211  	var returnValues []reflect.Value
   212  
   213  	r := &rpcRequest{
   214  		service:     req.msg.Target,
   215  		contentType: req.msg.Header["Content-Type"],
   216  		method:      req.msg.Method,
   217  		endpoint:    req.msg.Endpoint,
   218  		body:        req.msg.Body,
   219  	}
   220  
   221  	// only set if not nil
   222  	if argv.IsValid() {
   223  		r.rawBody = argv.Interface()
   224  	}
   225  
   226  	if !mtype.stream {
   227  		fn := func(ctx context.Context, req Request, rsp interface{}) error {
   228  			returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(argv.Interface()), reflect.ValueOf(rsp)})
   229  
   230  			// The return value for the method is an error.
   231  			if err := returnValues[0].Interface(); err != nil {
   232  				return err.(error)
   233  			}
   234  
   235  			return nil
   236  		}
   237  
   238  		// wrap the handler
   239  		for i := len(router.hdlrWrappers); i > 0; i-- {
   240  			fn = router.hdlrWrappers[i-1](fn)
   241  		}
   242  
   243  		// execute handler
   244  		if err := fn(ctx, r, replyv.Interface()); err != nil {
   245  			return err
   246  		}
   247  
   248  		// send response
   249  		return router.sendResponse(sending, req, replyv.Interface(), cc, true)
   250  	}
   251  
   252  	// declare a local error to see if we errored out already
   253  	// keep track of the type, to make sure we return
   254  	// the same one consistently
   255  	rawStream := &rpcStream{
   256  		context: ctx,
   257  		codec:   cc.(codec.Codec),
   258  		request: r,
   259  		id:      req.msg.Id,
   260  	}
   261  
   262  	// Invoke the method, providing a new value for the reply.
   263  	fn := func(ctx context.Context, req Request, stream interface{}) error {
   264  		returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)})
   265  		if err := returnValues[0].Interface(); err != nil {
   266  			// the function returned an error, we use that
   267  			return err.(error)
   268  		} else if serr := rawStream.Error(); serr == io.EOF || serr == io.ErrUnexpectedEOF {
   269  			return nil
   270  		} else {
   271  			// no error, we send the special EOS error
   272  			return lastStreamResponseError
   273  		}
   274  	}
   275  
   276  	// wrap the handler
   277  	for i := len(router.hdlrWrappers); i > 0; i-- {
   278  		fn = router.hdlrWrappers[i-1](fn)
   279  	}
   280  
   281  	// client.Stream request
   282  	r.stream = true
   283  
   284  	// execute handler
   285  	return fn(ctx, r, rawStream)
   286  }
   287  
   288  func (m *methodType) prepareContext(ctx context.Context) reflect.Value {
   289  	if contextv := reflect.ValueOf(ctx); contextv.IsValid() {
   290  		return contextv
   291  	}
   292  	return reflect.Zero(m.ContextType)
   293  }
   294  
   295  func (router *router) getRequest() *request {
   296  	router.reqLock.Lock()
   297  	req := router.freeReq
   298  	if req == nil {
   299  		req = new(request)
   300  	} else {
   301  		router.freeReq = req.next
   302  		*req = request{}
   303  	}
   304  	router.reqLock.Unlock()
   305  	return req
   306  }
   307  
   308  func (router *router) freeRequest(req *request) {
   309  	router.reqLock.Lock()
   310  	req.next = router.freeReq
   311  	router.freeReq = req
   312  	router.reqLock.Unlock()
   313  }
   314  
   315  func (router *router) getResponse() *response {
   316  	router.respLock.Lock()
   317  	resp := router.freeResp
   318  	if resp == nil {
   319  		resp = new(response)
   320  	} else {
   321  		router.freeResp = resp.next
   322  		*resp = response{}
   323  	}
   324  	router.respLock.Unlock()
   325  	return resp
   326  }
   327  
   328  func (router *router) freeResponse(resp *response) {
   329  	router.respLock.Lock()
   330  	resp.next = router.freeResp
   331  	router.freeResp = resp
   332  	router.respLock.Unlock()
   333  }
   334  
   335  func (router *router) readRequest(r Request) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) {
   336  	cc := r.Codec()
   337  
   338  	service, mtype, req, keepReading, err = router.readHeader(cc)
   339  	if err != nil {
   340  		if !keepReading {
   341  			return
   342  		}
   343  		// discard body
   344  		cc.ReadBody(nil)
   345  		return
   346  	}
   347  	// is it a streaming request? then we don't read the body
   348  	if mtype.stream {
   349  		cc.ReadBody(nil)
   350  		return
   351  	}
   352  
   353  	// Decode the argument value.
   354  	argIsValue := false // if true, need to indirect before calling.
   355  	if mtype.ArgType.Kind() == reflect.Ptr {
   356  		argv = reflect.New(mtype.ArgType.Elem())
   357  	} else {
   358  		argv = reflect.New(mtype.ArgType)
   359  		argIsValue = true
   360  	}
   361  	// argv guaranteed to be a pointer now.
   362  	if err = cc.ReadBody(argv.Interface()); err != nil {
   363  		return
   364  	}
   365  	if argIsValue {
   366  		argv = argv.Elem()
   367  	}
   368  
   369  	if !mtype.stream {
   370  		replyv = reflect.New(mtype.ReplyType.Elem())
   371  	}
   372  	return
   373  }
   374  
   375  func (router *router) readHeader(cc codec.Reader) (service *service, mtype *methodType, req *request, keepReading bool, err error) {
   376  	// Grab the request header.
   377  	msg := new(codec.Message)
   378  	msg.Type = codec.Request
   379  	req = router.getRequest()
   380  	req.msg = msg
   381  
   382  	err = cc.ReadHeader(msg, msg.Type)
   383  	if err != nil {
   384  		req = nil
   385  		if err == io.EOF || err == io.ErrUnexpectedEOF {
   386  			return
   387  		}
   388  		err = errors.New("rpc: router cannot decode request: " + err.Error())
   389  		return
   390  	}
   391  
   392  	// We read the header successfully. If we see an error now,
   393  	// we can still recover and move on to the next request.
   394  	keepReading = true
   395  
   396  	serviceMethod := strings.Split(req.msg.Endpoint, ".")
   397  	if len(serviceMethod) != 2 {
   398  		err = errors.New("rpc: service/endpoint request ill-formed: " + req.msg.Endpoint)
   399  		return
   400  	}
   401  	// Look up the request.
   402  	router.mu.Lock()
   403  	service = router.serviceMap[serviceMethod[0]]
   404  	router.mu.Unlock()
   405  	if service == nil {
   406  		err = errors.New("rpc: can't find service " + serviceMethod[0])
   407  		return
   408  	}
   409  	mtype = service.method[serviceMethod[1]]
   410  	if mtype == nil {
   411  		err = errors.New("rpc: can't find method " + serviceMethod[1])
   412  	}
   413  	return
   414  }
   415  
   416  func (router *router) NewHandler(h interface{}, opts ...HandlerOption) Handler {
   417  	return newRpcHandler(h, opts...)
   418  }
   419  
   420  func (router *router) Handle(h Handler) error {
   421  	router.mu.Lock()
   422  	defer router.mu.Unlock()
   423  	if router.serviceMap == nil {
   424  		router.serviceMap = make(map[string]*service)
   425  	}
   426  
   427  	if len(h.Name()) == 0 {
   428  		return errors.New("rpc.Handle: handler has no name")
   429  	}
   430  	if !isExported(h.Name()) {
   431  		return errors.New("rpc.Handle: type " + h.Name() + " is not exported")
   432  	}
   433  
   434  	rcvr := h.Handler()
   435  	s := new(service)
   436  	s.typ = reflect.TypeOf(rcvr)
   437  	s.rcvr = reflect.ValueOf(rcvr)
   438  
   439  	// check name
   440  	if _, present := router.serviceMap[h.Name()]; present {
   441  		return errors.New("rpc.Handle: service already defined: " + h.Name())
   442  	}
   443  
   444  	s.name = h.Name()
   445  	s.method = make(map[string]*methodType)
   446  
   447  	// Install the methods
   448  	for m := 0; m < s.typ.NumMethod(); m++ {
   449  		method := s.typ.Method(m)
   450  		if mt := prepareMethod(method); mt != nil {
   451  			s.method[method.Name] = mt
   452  		}
   453  	}
   454  
   455  	// Check there are methods
   456  	if len(s.method) == 0 {
   457  		return errors.New("rpc Register: type " + s.name + " has no exported methods of suitable type")
   458  	}
   459  
   460  	// save handler
   461  	router.serviceMap[s.name] = s
   462  	return nil
   463  }
   464  
   465  func (router *router) ServeRequest(ctx context.Context, r Request, rsp Response) error {
   466  	sending := new(sync.Mutex)
   467  	service, mtype, req, argv, replyv, keepReading, err := router.readRequest(r)
   468  	if err != nil {
   469  		if !keepReading {
   470  			return err
   471  		}
   472  		// send a response if we actually managed to read a header.
   473  		if req != nil {
   474  			router.freeRequest(req)
   475  		}
   476  		return err
   477  	}
   478  	return service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec())
   479  }
   480  
   481  func (router *router) NewSubscriber(topic string, handler interface{}, opts ...SubscriberOption) Subscriber {
   482  	return newSubscriber(topic, handler, opts...)
   483  }
   484  
   485  func (router *router) Subscribe(s Subscriber) error {
   486  	sub, ok := s.(*subscriber)
   487  	if !ok {
   488  		return fmt.Errorf("invalid subscriber: expected *subscriber")
   489  	}
   490  	if len(sub.handlers) == 0 {
   491  		return fmt.Errorf("invalid subscriber: no handler functions")
   492  	}
   493  
   494  	if err := validateSubscriber(sub); err != nil {
   495  		return err
   496  	}
   497  
   498  	router.su.Lock()
   499  	defer router.su.Unlock()
   500  
   501  	// append to subscribers
   502  	subs := router.subscribers[sub.Topic()]
   503  	subs = append(subs, sub)
   504  	router.subscribers[sub.Topic()] = subs
   505  
   506  	return nil
   507  }
   508  
   509  func (router *router) ProcessMessage(ctx context.Context, msg Message) error {
   510  	var err error
   511  
   512  	defer func() {
   513  		// recover any panics
   514  		if r := recover(); r != nil {
   515  			log.Log("panic recovered: ", r)
   516  			log.Log(string(debug.Stack()))
   517  			err = merrors.InternalServerError("go.micro.server", "panic recovered: %v", r)
   518  		}
   519  	}()
   520  
   521  	router.su.RLock()
   522  
   523  	// get the subscribers by topic
   524  	subs, ok := router.subscribers[msg.Topic()]
   525  	if !ok {
   526  		router.su.RUnlock()
   527  		return nil
   528  	}
   529  
   530  	// unlock since we only need to get the subs
   531  	router.su.RUnlock()
   532  
   533  	var errResults []string
   534  
   535  	// we may have multiple subscribers for the topic
   536  	for _, sub := range subs {
   537  		// we may have multiple handlers per subscriber
   538  		for i := 0; i < len(sub.handlers); i++ {
   539  			// get the handler
   540  			handler := sub.handlers[i]
   541  
   542  			var isVal bool
   543  			var req reflect.Value
   544  
   545  			// check whether the handler is a pointer
   546  			if handler.reqType.Kind() == reflect.Ptr {
   547  				req = reflect.New(handler.reqType.Elem())
   548  			} else {
   549  				req = reflect.New(handler.reqType)
   550  				isVal = true
   551  			}
   552  
   553  			// if its a value get the element
   554  			if isVal {
   555  				req = req.Elem()
   556  			}
   557  
   558  			if handler.reqType.Kind() == reflect.Ptr {
   559  				req = reflect.New(handler.reqType.Elem())
   560  			} else {
   561  				req = reflect.New(handler.reqType)
   562  				isVal = true
   563  			}
   564  
   565  			// if its a value get the element
   566  			if isVal {
   567  				req = req.Elem()
   568  			}
   569  
   570  			cc := msg.Codec()
   571  
   572  			// read the header. mostly a noop
   573  			if err = cc.ReadHeader(&codec.Message{}, codec.Event); err != nil {
   574  				return err
   575  			}
   576  
   577  			// read the body into the handler request value
   578  			if err = cc.ReadBody(req.Interface()); err != nil {
   579  				return err
   580  			}
   581  
   582  			// create the handler which will honour the SubscriberFunc type
   583  			fn := func(ctx context.Context, msg Message) error {
   584  				var vals []reflect.Value
   585  				if sub.typ.Kind() != reflect.Func {
   586  					vals = append(vals, sub.rcvr)
   587  				}
   588  				if handler.ctxType != nil {
   589  					vals = append(vals, reflect.ValueOf(ctx))
   590  				}
   591  
   592  				// values to pass the handler
   593  				vals = append(vals, reflect.ValueOf(msg.Payload()))
   594  
   595  				// execute the actuall call of the handler
   596  				returnValues := handler.method.Call(vals)
   597  				if rerr := returnValues[0].Interface(); rerr != nil {
   598  					err = rerr.(error)
   599  				}
   600  				return err
   601  			}
   602  
   603  			// wrap with subscriber wrappers
   604  			for i := len(router.subWrappers); i > 0; i-- {
   605  				fn = router.subWrappers[i-1](fn)
   606  			}
   607  
   608  			// create new rpc message
   609  			rpcMsg := &rpcMessage{
   610  				topic:       msg.Topic(),
   611  				contentType: msg.ContentType(),
   612  				payload:     req.Interface(),
   613  				codec:       msg.(*rpcMessage).codec,
   614  				header:      msg.Header(),
   615  				body:        msg.Body(),
   616  			}
   617  
   618  			// execute the message handler
   619  			if err = fn(ctx, rpcMsg); err != nil {
   620  				errResults = append(errResults, err.Error())
   621  			}
   622  		}
   623  	}
   624  
   625  	// if no errors just return
   626  	if len(errResults) > 0 {
   627  		err = merrors.InternalServerError("go.micro.server", "subscriber error: %v", strings.Join(errResults, "\n"))
   628  	}
   629  
   630  	return err
   631  }