github.com/tickoalcantara12/micro/v3@v3.0.0-20221007104245-9d75b9bcbab9/service/server/mucp/rpc_router.go (about)

     1  // Copyright 2020 Asim Aslam
     2  //
     3  // Licensed under the Apache License, Version 2.0 (the "License");
     4  // you may not use this file except in compliance with the License.
     5  // You may obtain a copy of the License at
     6  //
     7  //     https://www.apache.org/licenses/LICENSE-2.0
     8  //
     9  // Unless required by applicable law or agreed to in writing, software
    10  // distributed under the License is distributed on an "AS IS" BASIS,
    11  // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
    12  // See the License for the specific language governing permissions and
    13  // limitations under the License.
    14  //
    15  // Original source: github.com/micro/go-micro/v3/server/mucp/rpc_router.go
    16  
    17  package mucp
    18  
    19  // Copyright 2009 The Go Authors. All rights reserved.
    20  // Use of this source code is governed by a BSD-style
    21  // license that can be found in the LICENSE file.
    22  //
    23  // Meh, we need to get rid of this shit
    24  
    25  import (
    26  	"context"
    27  	"errors"
    28  	"fmt"
    29  	"io"
    30  	"reflect"
    31  	"runtime/debug"
    32  	"strings"
    33  	"sync"
    34  	"unicode"
    35  	"unicode/utf8"
    36  
    37  	merrors "github.com/tickoalcantara12/micro/v3/service/errors"
    38  	"github.com/tickoalcantara12/micro/v3/service/server"
    39  	"github.com/tickoalcantara12/micro/v3/util/codec"
    40  )
    41  
    42  var (
    43  	lastStreamResponseError = errors.New("EOS")
    44  
    45  	// Precompute the reflect type for error. Can't use error directly
    46  	// because Typeof takes an empty interface value. This is annoying.
    47  	typeOfError = reflect.TypeOf((*error)(nil)).Elem()
    48  )
    49  
    50  type methodType struct {
    51  	sync.Mutex  // protects counters
    52  	method      reflect.Method
    53  	ArgType     reflect.Type
    54  	ReplyType   reflect.Type
    55  	ContextType reflect.Type
    56  	stream      bool
    57  }
    58  
    59  type service struct {
    60  	name   string                 // name of service
    61  	rcvr   reflect.Value          // receiver of methods for the service
    62  	typ    reflect.Type           // type of the receiver
    63  	method map[string]*methodType // registered methods
    64  }
    65  
    66  type request struct {
    67  	msg  *codec.Message
    68  	next *request // for free list in Server
    69  }
    70  
    71  type response struct {
    72  	msg  *codec.Message
    73  	next *response // for free list in Server
    74  }
    75  
    76  // router represents an RPC router.
    77  type router struct {
    78  	name string
    79  
    80  	mu         sync.Mutex // protects the serviceMap
    81  	serviceMap map[string]*service
    82  
    83  	reqLock sync.Mutex // protects freeReq
    84  	freeReq *request
    85  
    86  	respLock sync.Mutex // protects freeResp
    87  	freeResp *response
    88  
    89  	// handler wrappers
    90  	hdlrWrappers []server.HandlerWrapper
    91  	// subscriber wrappers
    92  	subWrappers []server.SubscriberWrapper
    93  
    94  	su          sync.RWMutex
    95  	subscribers map[string][]*subscriber
    96  }
    97  
    98  // rpcRouter encapsulates functions that become a server.Router
    99  type rpcRouter struct {
   100  	h func(context.Context, server.Request, interface{}) error
   101  	m func(context.Context, server.Message) error
   102  }
   103  
   104  func (r rpcRouter) ProcessMessage(ctx context.Context, msg server.Message) error {
   105  	return r.m(ctx, msg)
   106  }
   107  
   108  func (r rpcRouter) ServeRequest(ctx context.Context, req server.Request, rsp server.Response) error {
   109  	return r.h(ctx, req, rsp)
   110  }
   111  
   112  func newRpcRouter() *router {
   113  	return &router{
   114  		serviceMap:  make(map[string]*service),
   115  		subscribers: make(map[string][]*subscriber),
   116  	}
   117  }
   118  
   119  // Is this an exported - upper case - name?
   120  func isExported(name string) bool {
   121  	rune, _ := utf8.DecodeRuneInString(name)
   122  	return unicode.IsUpper(rune)
   123  }
   124  
   125  // Is this type exported or a builtin?
   126  func isExportedOrBuiltinType(t reflect.Type) bool {
   127  	for t.Kind() == reflect.Ptr {
   128  		t = t.Elem()
   129  	}
   130  	// PkgPath will be non-empty even for an exported type,
   131  	// so we need to check the type name as well.
   132  	return isExported(t.Name()) || t.PkgPath() == ""
   133  }
   134  
   135  // prepareMethod returns a methodType for the provided method or nil
   136  // in case if the method was unsuitable.
   137  func prepareMethod(method reflect.Method) *methodType {
   138  	mtype := method.Type
   139  	mname := method.Name
   140  	var replyType, argType, contextType reflect.Type
   141  	var stream bool
   142  
   143  	// Method must be exported.
   144  	if method.PkgPath != "" {
   145  		return nil
   146  	}
   147  
   148  	switch mtype.NumIn() {
   149  	case 3:
   150  		// assuming streaming
   151  		argType = mtype.In(2)
   152  		contextType = mtype.In(1)
   153  		stream = true
   154  	case 4:
   155  		// method that takes a context
   156  		argType = mtype.In(2)
   157  		replyType = mtype.In(3)
   158  		contextType = mtype.In(1)
   159  	default:
   160  		log.Errorf("method %v of %v has wrong number of ins: %v", mname, mtype, mtype.NumIn())
   161  		return nil
   162  	}
   163  
   164  	if stream {
   165  		// check stream type
   166  		streamType := reflect.TypeOf((*server.Stream)(nil)).Elem()
   167  		if !argType.Implements(streamType) {
   168  			log.Errorf("%v argument does not implement Stream interface: %v", mname, argType)
   169  			return nil
   170  		}
   171  	} else {
   172  		// if not stream check the replyType
   173  
   174  		// First arg need not be a pointer.
   175  		if !isExportedOrBuiltinType(argType) {
   176  			log.Errorf("%v argument type not exported: %v", mname, argType)
   177  			return nil
   178  		}
   179  
   180  		if replyType.Kind() != reflect.Ptr {
   181  			log.Errorf("method %v reply type not a pointer: %v", mname, replyType)
   182  			return nil
   183  		}
   184  
   185  		// Reply type must be exported.
   186  		if !isExportedOrBuiltinType(replyType) {
   187  			log.Errorf("method %v reply type not exported: %v", mname, replyType)
   188  			return nil
   189  		}
   190  	}
   191  
   192  	// Method needs one out.
   193  	if mtype.NumOut() != 1 {
   194  		log.Errorf("method %v has wrong number of outs: %v", mname, mtype.NumOut())
   195  		return nil
   196  	}
   197  	// The return type of the method must be error.
   198  	if returnType := mtype.Out(0); returnType != typeOfError {
   199  		log.Errorf("method %v returns %v not error", mname, returnType.String())
   200  		return nil
   201  	}
   202  	return &methodType{method: method, ArgType: argType, ReplyType: replyType, ContextType: contextType, stream: stream}
   203  }
   204  
   205  func (router *router) sendResponse(sending sync.Locker, req *request, reply interface{}, cc codec.Writer, last bool) error {
   206  	msg := new(codec.Message)
   207  	msg.Type = codec.Response
   208  	resp := router.getResponse()
   209  	resp.msg = msg
   210  
   211  	resp.msg.Id = req.msg.Id
   212  	sending.Lock()
   213  	err := cc.Write(resp.msg, reply)
   214  	sending.Unlock()
   215  	router.freeResponse(resp)
   216  	return err
   217  }
   218  
   219  func (s *service) call(ctx context.Context, router *router, sending *sync.Mutex, mtype *methodType, req *request, argv, replyv reflect.Value, cc codec.Writer) error {
   220  	defer router.freeRequest(req)
   221  
   222  	function := mtype.method.Func
   223  	var returnValues []reflect.Value
   224  
   225  	r := &rpcRequest{
   226  		service:     req.msg.Target,
   227  		contentType: req.msg.Header["Content-Type"],
   228  		method:      req.msg.Method,
   229  		endpoint:    req.msg.Endpoint,
   230  		body:        req.msg.Body,
   231  		header:      req.msg.Header,
   232  	}
   233  
   234  	// only set if not nil
   235  	if argv.IsValid() {
   236  		r.rawBody = argv.Interface()
   237  	}
   238  
   239  	if !mtype.stream {
   240  		fn := func(ctx context.Context, req server.Request, rsp interface{}) error {
   241  			returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(argv.Interface()), reflect.ValueOf(rsp)})
   242  
   243  			// The return value for the method is an error.
   244  			if err := returnValues[0].Interface(); err != nil {
   245  				return err.(error)
   246  			}
   247  
   248  			return nil
   249  		}
   250  
   251  		// wrap the handler
   252  		for i := len(router.hdlrWrappers); i > 0; i-- {
   253  			fn = router.hdlrWrappers[i-1](fn)
   254  		}
   255  
   256  		// execute handler
   257  		if err := fn(ctx, r, replyv.Interface()); err != nil {
   258  			return err
   259  		}
   260  
   261  		// send response
   262  		return router.sendResponse(sending, req, replyv.Interface(), cc, true)
   263  	}
   264  
   265  	// declare a local error to see if we errored out already
   266  	// keep track of the type, to make sure we return
   267  	// the same one consistently
   268  	rawStream := &rpcStream{
   269  		context: ctx,
   270  		codec:   cc.(codec.Codec),
   271  		request: r,
   272  		id:      req.msg.Id,
   273  	}
   274  
   275  	// Invoke the method, providing a new value for the reply.
   276  	fn := func(ctx context.Context, req server.Request, stream interface{}) error {
   277  		returnValues = function.Call([]reflect.Value{s.rcvr, mtype.prepareContext(ctx), reflect.ValueOf(stream)})
   278  		if err := returnValues[0].Interface(); err != nil {
   279  			// the function returned an error, we use that
   280  			return err.(error)
   281  		} else if serr := rawStream.Error(); serr == io.EOF || serr == io.ErrUnexpectedEOF {
   282  			return nil
   283  		} else {
   284  			// no error, we send the special EOS error
   285  			return lastStreamResponseError
   286  		}
   287  	}
   288  
   289  	// wrap the handler
   290  	for i := len(router.hdlrWrappers); i > 0; i-- {
   291  		fn = router.hdlrWrappers[i-1](fn)
   292  	}
   293  
   294  	// client.Stream request
   295  	r.stream = true
   296  
   297  	// execute handler
   298  	return fn(ctx, r, rawStream)
   299  }
   300  
   301  func (m *methodType) prepareContext(ctx context.Context) reflect.Value {
   302  	if contextv := reflect.ValueOf(ctx); contextv.IsValid() {
   303  		return contextv
   304  	}
   305  	return reflect.Zero(m.ContextType)
   306  }
   307  
   308  func (router *router) getRequest() *request {
   309  	router.reqLock.Lock()
   310  	req := router.freeReq
   311  	if req == nil {
   312  		req = new(request)
   313  	} else {
   314  		router.freeReq = req.next
   315  		*req = request{}
   316  	}
   317  	router.reqLock.Unlock()
   318  	return req
   319  }
   320  
   321  func (router *router) freeRequest(req *request) {
   322  	router.reqLock.Lock()
   323  	req.next = router.freeReq
   324  	router.freeReq = req
   325  	router.reqLock.Unlock()
   326  }
   327  
   328  func (router *router) getResponse() *response {
   329  	router.respLock.Lock()
   330  	resp := router.freeResp
   331  	if resp == nil {
   332  		resp = new(response)
   333  	} else {
   334  		router.freeResp = resp.next
   335  		*resp = response{}
   336  	}
   337  	router.respLock.Unlock()
   338  	return resp
   339  }
   340  
   341  func (router *router) freeResponse(resp *response) {
   342  	router.respLock.Lock()
   343  	resp.next = router.freeResp
   344  	router.freeResp = resp
   345  	router.respLock.Unlock()
   346  }
   347  
   348  func (router *router) readRequest(r server.Request) (service *service, mtype *methodType, req *request, argv, replyv reflect.Value, keepReading bool, err error) {
   349  	cc := r.Codec()
   350  
   351  	service, mtype, req, keepReading, err = router.readHeader(cc)
   352  	if err != nil {
   353  		if !keepReading {
   354  			return
   355  		}
   356  		// discard body
   357  		cc.ReadBody(nil)
   358  		return
   359  	}
   360  	// is it a streaming request? then we don't read the body
   361  	if mtype.stream {
   362  		if cc.(codec.Codec).String() != "grpc" {
   363  			cc.ReadBody(nil)
   364  		}
   365  		return
   366  	}
   367  
   368  	// Decode the argument value.
   369  	argIsValue := false // if true, need to indirect before calling.
   370  	if mtype.ArgType.Kind() == reflect.Ptr {
   371  		argv = reflect.New(mtype.ArgType.Elem())
   372  	} else {
   373  		argv = reflect.New(mtype.ArgType)
   374  		argIsValue = true
   375  	}
   376  	// argv guaranteed to be a pointer now.
   377  	if err = cc.ReadBody(argv.Interface()); err != nil {
   378  		return
   379  	}
   380  	if argIsValue {
   381  		argv = argv.Elem()
   382  	}
   383  
   384  	if !mtype.stream {
   385  		replyv = reflect.New(mtype.ReplyType.Elem())
   386  	}
   387  	return
   388  }
   389  
   390  func (router *router) readHeader(cc codec.Reader) (service *service, mtype *methodType, req *request, keepReading bool, err error) {
   391  	// Grab the request header.
   392  	msg := new(codec.Message)
   393  	msg.Type = codec.Request
   394  	req = router.getRequest()
   395  	req.msg = msg
   396  
   397  	err = cc.ReadHeader(msg, msg.Type)
   398  	if err != nil {
   399  		req = nil
   400  		if err == io.EOF || err == io.ErrUnexpectedEOF {
   401  			return
   402  		}
   403  		err = errors.New("rpc: router cannot decode request: " + err.Error())
   404  		return
   405  	}
   406  
   407  	// We read the header successfully. If we see an error now,
   408  	// we can still recover and move on to the next request.
   409  	keepReading = true
   410  
   411  	serviceMethod := strings.Split(req.msg.Endpoint, ".")
   412  	if len(serviceMethod) != 2 {
   413  		err = errors.New("rpc: service/endpoint request ill-formed: " + req.msg.Endpoint)
   414  		return
   415  	}
   416  	// Look up the request.
   417  	router.mu.Lock()
   418  	service = router.serviceMap[serviceMethod[0]]
   419  	router.mu.Unlock()
   420  	if service == nil {
   421  		err = errors.New("rpc: can't find service " + serviceMethod[0])
   422  		return
   423  	}
   424  	mtype = service.method[serviceMethod[1]]
   425  	if mtype == nil {
   426  		err = errors.New("rpc: can't find method " + serviceMethod[1])
   427  	}
   428  	return
   429  }
   430  
   431  func (router *router) NewHandler(h interface{}, opts ...server.HandlerOption) server.Handler {
   432  	return newRpcHandler(h, opts...)
   433  }
   434  
   435  func (router *router) Handle(h server.Handler) error {
   436  	router.mu.Lock()
   437  	defer router.mu.Unlock()
   438  	if router.serviceMap == nil {
   439  		router.serviceMap = make(map[string]*service)
   440  	}
   441  
   442  	if len(h.Name()) == 0 {
   443  		return errors.New("rpc.Handle: handler has no name")
   444  	}
   445  	if !isExported(h.Name()) {
   446  		return errors.New("rpc.Handle: type " + h.Name() + " is not exported")
   447  	}
   448  
   449  	rcvr := h.Handler()
   450  	s := new(service)
   451  	s.typ = reflect.TypeOf(rcvr)
   452  	s.rcvr = reflect.ValueOf(rcvr)
   453  
   454  	// check name
   455  	if _, present := router.serviceMap[h.Name()]; present {
   456  		return errors.New("rpc.Handle: service already defined: " + h.Name())
   457  	}
   458  
   459  	s.name = h.Name()
   460  	s.method = make(map[string]*methodType)
   461  
   462  	// Install the methods
   463  	for m := 0; m < s.typ.NumMethod(); m++ {
   464  		method := s.typ.Method(m)
   465  		if mt := prepareMethod(method); mt != nil {
   466  			s.method[method.Name] = mt
   467  		}
   468  	}
   469  
   470  	// Check there are methods
   471  	if len(s.method) == 0 {
   472  		return errors.New("rpc Register: type " + s.name + " has no exported methods of suitable type")
   473  	}
   474  
   475  	// save handler
   476  	router.serviceMap[s.name] = s
   477  	return nil
   478  }
   479  
   480  func (router *router) ServeRequest(ctx context.Context, r server.Request, rsp server.Response) error {
   481  	sending := new(sync.Mutex)
   482  	service, mtype, req, argv, replyv, keepReading, err := router.readRequest(r)
   483  	if err != nil {
   484  		if !keepReading {
   485  			return err
   486  		}
   487  		// send a response if we actually managed to read a header.
   488  		if req != nil {
   489  			router.freeRequest(req)
   490  		}
   491  		return err
   492  	}
   493  	return service.call(ctx, router, sending, mtype, req, argv, replyv, rsp.Codec())
   494  }
   495  
   496  func (router *router) NewSubscriber(topic string, handler interface{}, opts ...server.SubscriberOption) server.Subscriber {
   497  	return newSubscriber(topic, handler, opts...)
   498  }
   499  
   500  func (router *router) Subscribe(s server.Subscriber) error {
   501  	sub, ok := s.(*subscriber)
   502  	if !ok {
   503  		return fmt.Errorf("invalid subscriber: expected *subscriber")
   504  	}
   505  	if len(sub.handlers) == 0 {
   506  		return fmt.Errorf("invalid subscriber: no handler functions")
   507  	}
   508  
   509  	if err := validateSubscriber(sub); err != nil {
   510  		return err
   511  	}
   512  
   513  	router.su.Lock()
   514  	defer router.su.Unlock()
   515  
   516  	// append to subscribers
   517  	subs := router.subscribers[sub.Topic()]
   518  	subs = append(subs, sub)
   519  	router.subscribers[sub.Topic()] = subs
   520  
   521  	return nil
   522  }
   523  
   524  func (router *router) String() string {
   525  	return "mucp"
   526  }
   527  
   528  func (router *router) ProcessMessage(ctx context.Context, msg server.Message) (err error) {
   529  	defer func() {
   530  		// recover any panics
   531  		if r := recover(); r != nil {
   532  			log.Errorf("panic recovered: %v", r)
   533  			log.Error(string(debug.Stack()))
   534  			err = merrors.InternalServerError("go.micro.server", "panic recovered: %v", r)
   535  		}
   536  	}()
   537  
   538  	router.su.RLock()
   539  	// get the subscribers by topic
   540  	subs, ok := router.subscribers[msg.Topic()]
   541  	// unlock since we only need to get the subs
   542  	router.su.RUnlock()
   543  	if !ok {
   544  		return nil
   545  	}
   546  
   547  	var errResults []string
   548  
   549  	// we may have multiple subscribers for the topic
   550  	for _, sub := range subs {
   551  		// we may have multiple handlers per subscriber
   552  		for i := 0; i < len(sub.handlers); i++ {
   553  			// get the handler
   554  			handler := sub.handlers[i]
   555  
   556  			var isVal bool
   557  			var req reflect.Value
   558  
   559  			// check whether the handler is a pointer
   560  			if handler.reqType.Kind() == reflect.Ptr {
   561  				req = reflect.New(handler.reqType.Elem())
   562  			} else {
   563  				req = reflect.New(handler.reqType)
   564  				isVal = true
   565  			}
   566  
   567  			// if its a value get the element
   568  			if isVal {
   569  				req = req.Elem()
   570  			}
   571  
   572  			cc := msg.Codec()
   573  
   574  			// read the header. mostly a noop
   575  			if err = cc.ReadHeader(&codec.Message{}, codec.Event); err != nil {
   576  				return err
   577  			}
   578  
   579  			// read the body into the handler request value
   580  			if err = cc.ReadBody(req.Interface()); err != nil {
   581  				return err
   582  			}
   583  
   584  			// create the handler which will honour the SubscriberFunc type
   585  			fn := func(ctx context.Context, msg server.Message) error {
   586  				var vals []reflect.Value
   587  				if sub.typ.Kind() != reflect.Func {
   588  					vals = append(vals, sub.rcvr)
   589  				}
   590  				if handler.ctxType != nil {
   591  					vals = append(vals, reflect.ValueOf(ctx))
   592  				}
   593  
   594  				// values to pass the handler
   595  				vals = append(vals, reflect.ValueOf(msg.Payload()))
   596  
   597  				// execute the actuall call of the handler
   598  				returnValues := handler.method.Call(vals)
   599  				if rerr := returnValues[0].Interface(); rerr != nil {
   600  					err = rerr.(error)
   601  				}
   602  				return err
   603  			}
   604  
   605  			// wrap with subscriber wrappers
   606  			for i := len(router.subWrappers); i > 0; i-- {
   607  				fn = router.subWrappers[i-1](fn)
   608  			}
   609  
   610  			// create new rpc message
   611  			rpcMsg := &rpcMessage{
   612  				topic:       msg.Topic(),
   613  				contentType: msg.ContentType(),
   614  				payload:     req.Interface(),
   615  				codec:       msg.(*rpcMessage).codec,
   616  				header:      msg.Header(),
   617  				body:        msg.Body(),
   618  			}
   619  
   620  			// execute the message handler
   621  			if err = fn(ctx, rpcMsg); err != nil {
   622  				errResults = append(errResults, err.Error())
   623  			}
   624  		}
   625  	}
   626  
   627  	// if no errors just return
   628  	if len(errResults) > 0 {
   629  		err = merrors.InternalServerError("go.micro.server", "subscriber error: %v", strings.Join(errResults, "\n"))
   630  	}
   631  
   632  	return err
   633  }