github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/server/message_handle.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"github.com/nyan233/littlerpc/core/common/errorhandler"
     7  	"github.com/nyan233/littlerpc/core/common/inters"
     8  	"github.com/nyan233/littlerpc/core/common/metadata"
     9  	metaDataUtil "github.com/nyan233/littlerpc/core/common/utils/metadata"
    10  	error2 "github.com/nyan233/littlerpc/core/protocol/error"
    11  	message2 "github.com/nyan233/littlerpc/core/protocol/message"
    12  	"github.com/nyan233/littlerpc/core/utils/convert"
    13  	reflect2 "github.com/nyan233/littlerpc/internal/reflect"
    14  	"reflect"
    15  	"runtime"
    16  	"strconv"
    17  	"time"
    18  )
    19  
    20  var (
    21  	hijackResultCache = []reflect.Value{reflect.ValueOf(errorhandler.Success)}
    22  )
    23  
    24  // 过程中的副作用会导致msgOpt.Message在调用结束之前被放回pasrser中
    25  func (s *Server) messageKeepAlive(msgOpt *messageOpt) {
    26  	defer func() {
    27  		msgOpt.Free()
    28  		msgOpt.FreePluginCtx()
    29  	}()
    30  	if err := msgOpt.RealPayload(); err != nil {
    31  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgOpt.Message.GetMsgId(), s.eHandle.LWarpErrorDesc(
    32  			err, "keep-alive get real payload failed"))
    33  		return
    34  	}
    35  	msgOpt.Message.SetMsgType(message2.Pong)
    36  	cfg := s.config.Load()
    37  	if cfg.KeepAlive {
    38  		err := msgOpt.Conn.SetDeadline(time.Now().Add(cfg.KeepAliveTimeout))
    39  		if err != nil {
    40  			s.logger.Error("LRPC: connection set deadline failed: %v", err)
    41  			_ = msgOpt.Conn.Close()
    42  			return
    43  		}
    44  	}
    45  	s.encodeAndSendMsg(msgOpt, msgOpt.Message, nil, false)
    46  }
    47  
    48  // 过程中的副作用会导致msgOpt.Message在调用结束之前被放回pasrser中
    49  func (s *Server) messageContextCancel(msgOpt *messageOpt) {
    50  	defer func() {
    51  		msgOpt.Free()
    52  		msgOpt.FreePluginCtx()
    53  	}()
    54  	if err := msgOpt.RealPayload(); err != nil {
    55  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgOpt.Message.GetMsgId(), s.eHandle.LWarpErrorDesc(
    56  			err, "context-cancel get real payload failed"))
    57  		return
    58  	}
    59  	ctxIdStr, ok := msgOpt.Message.MetaData.LoadOk(message2.ContextId)
    60  	if !ok {
    61  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgOpt.Message.GetMsgId(), error2.LWarpStdError(
    62  			errorhandler.ContextNotFound, fmt.Sprintf("contextId : %s", ctxIdStr)))
    63  	}
    64  	ctxId, err := strconv.ParseUint(ctxIdStr, 10, 64)
    65  	if err != nil {
    66  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgOpt.Message.GetMsgId(), error2.LWarpStdError(
    67  			errorhandler.ErrServer, err.Error()))
    68  	}
    69  	err = msgOpt.Desc.ctxManager.CancelContext(ctxId)
    70  	if err != nil {
    71  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgOpt.Message.GetMsgId(), error2.LWarpStdError(
    72  			errorhandler.ErrServer, err.Error()))
    73  		return
    74  	}
    75  	s.encodeAndSendMsg(msgOpt, msgOpt.Message, nil, false)
    76  }
    77  
    78  // 过程中的副作用会导致msgOpt.Message在调用结束之前被放回pasrser中
    79  func (s *Server) messageCall(msgOpt *messageOpt, desc *connSourceDesc) {
    80  	msgId := msgOpt.Message.GetMsgId()
    81  	var err error
    82  	defer func() {
    83  		if err != nil {
    84  			msgOpt.Free()
    85  			msgOpt.FreePluginCtx()
    86  		}
    87  	}()
    88  	err = msgOpt.RealPayload()
    89  	if err != nil {
    90  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err.(error2.LErrorDesc))
    91  		return
    92  	}
    93  	err = msgOpt.checkService()
    94  	if err != nil {
    95  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err.(error2.LErrorDesc))
    96  		return
    97  	}
    98  	callHandler := s.callHandleUnit
    99  	if msgOpt.Hijack() {
   100  		callHandler = s.hijackCall
   101  	} else {
   102  		err = msgOpt.Check()
   103  		if err != nil {
   104  			s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err.(error2.LErrorDesc))
   105  			return
   106  		}
   107  	}
   108  	switch {
   109  	case msgOpt.Service.Option.SyncCall:
   110  		callHandler(msgOpt)
   111  	case msgOpt.Service.Option.UseRawGoroutine:
   112  		go func() {
   113  			callHandler(msgOpt)
   114  		}()
   115  	default:
   116  		err = s.taskPool.Push(msgOpt.Message.GetServiceName(), func() {
   117  			callHandler(msgOpt)
   118  		})
   119  		if err != nil {
   120  			s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, s.eHandle.LWarpErrorDesc(errorhandler.ErrServer, err.Error()))
   121  		}
   122  	}
   123  }
   124  
   125  // 提供用于任务池的处理调用用户过程的单元
   126  // 因为用户过程可能会有阻塞操作
   127  func (s *Server) callHandleUnit(msgOpt *messageOpt) {
   128  	msgId := msgOpt.Message.GetMsgId()
   129  	msgOpt.Free()
   130  
   131  	messageBuffer := s.pool.TakeMessagePool()
   132  	msg := messageBuffer.Get().(*message2.Message)
   133  	msg.Reset()
   134  	defer func() {
   135  		message2.ResetMsg(msg, false, true, true, 1024)
   136  		messageBuffer.Put(msg)
   137  		msgOpt.FreePluginCtx()
   138  	}()
   139  	callResult, cErr := s.handleCall(msgOpt.Service, msgOpt.CallArgs)
   140  	// context存在时且未被取消, 则在调用结束之后取消
   141  	if msgOpt.Service.SupportContext && msgOpt.CallArgs[0].Interface().(context.Context).Err() == nil && msgOpt.Cancel != nil {
   142  		msgOpt.Cancel()
   143  	}
   144  
   145  	if cErr == nil && len(callResult) == 0 {
   146  		// TODO v0.4.x计划删除
   147  		// 函数在没有返回error则填充nil
   148  		callResult = append(callResult, reflect.ValueOf(nil))
   149  	}
   150  	err := s.pManager.AfterCall4S(msgOpt.PCtx, msgOpt.CallArgs, callResult, cErr)
   151  	// AfterCall4S()之后不会再被使用, 可以回收参数
   152  	if msgOpt.Service.Option.CompleteReUsage {
   153  		for i := metaDataUtil.InputOffset(msgOpt.Service); i < len(msgOpt.CallArgs); i++ {
   154  			msgOpt.CallArgs[i].Interface().(inters.Reset).Reset()
   155  		}
   156  		msgOpt.Service.Pool.Put(msgOpt.CallArgs)
   157  		// 置空, 防止放回池中时被其它goroutine重新引用而导致数据竞争, 导致难以排查
   158  		msgOpt.CallArgs = nil
   159  	}
   160  	if err != nil {
   161  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err)
   162  		return
   163  	}
   164  	if cErr != nil {
   165  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, cErr)
   166  		return
   167  	}
   168  	s.reply(msgOpt, msg, msgId, callResult)
   169  }
   170  
   171  func (s *Server) hijackCall(msgOpt *messageOpt) {
   172  	defer msgOpt.Free()
   173  	msgId := msgOpt.Message.GetMsgId()
   174  	ctx, err := msgOpt.getContext()
   175  	if err != nil {
   176  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err)
   177  		return
   178  	}
   179  	localPool := s.pool.TakeMessagePool()
   180  	replyMsg := localPool.Get().(*message2.Message)
   181  	replyMsg.Reset()
   182  	stub := &Stub{
   183  		opt:     msgOpt,
   184  		reply:   replyMsg,
   185  		Context: ctx,
   186  	}
   187  	defer localPool.Put(replyMsg)
   188  	err = s.handleCallOnHijack(msgOpt.Service, stub)
   189  	if err != nil {
   190  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err)
   191  		return
   192  	}
   193  	if msgOpt.Service.SupportContext && stub.Context.Err() == nil && msgOpt.Cancel != nil {
   194  		msgOpt.Cancel()
   195  	}
   196  	var hijackResults []reflect.Value
   197  	if stub.callErr == nil {
   198  		hijackResults = hijackResultCache
   199  	} else {
   200  		hijackResults = []reflect.Value{reflect.ValueOf(stub.callErr)}
   201  	}
   202  	err = s.pManager.AfterCall4S(msgOpt.PCtx, nil, hijackResults, nil)
   203  	if err != nil {
   204  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err)
   205  		return
   206  	}
   207  	s.replyOnHijack(msgOpt, replyMsg, msgId, stub.callErr)
   208  }
   209  
   210  func (s *Server) replyOnHijack(msgOpt *messageOpt, msg *message2.Message, msgId uint64, callErr error) {
   211  	msgOpt.initReplyMsg(msg, msgId)
   212  	err := s.setErr(msg, callErr)
   213  	s.encodeAndSendMsg(msgOpt, msg, err, true)
   214  }
   215  
   216  func (s *Server) reply(msgOpt *messageOpt, msg *message2.Message, msgId uint64, results []reflect.Value) {
   217  	msgOpt.initReplyMsg(msg, msgId)
   218  	// 处理用户过程返回的错误,v0.30开始规定每个符合规范的API最后一个返回值是error接口
   219  	lErr := s.setErrResult(msg, results[len(results)-1])
   220  	if lErr != nil {
   221  		s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msg.GetMsgId(), lErr)
   222  		return
   223  	}
   224  	err := s.handleResult(msgOpt, msg, results)
   225  	s.encodeAndSendMsg(msgOpt, msg, err, true)
   226  }
   227  
   228  func (s *Server) handleCall(service *metadata.Process, args []reflect.Value) (results []reflect.Value, err error2.LErrorDesc) {
   229  	defer s.processCallRecover(&err)
   230  	results = service.Value.Call(args)
   231  	return
   232  }
   233  
   234  func (s *Server) handleCallOnHijack(service *metadata.Process, stub *Stub) (err error2.LErrorDesc) {
   235  	defer s.processCallRecover(&err)
   236  	fun := *(*func(stub *Stub))(service.Hijacker)
   237  	stub.setup()
   238  	fun(stub)
   239  	return nil
   240  }
   241  
   242  func (s *Server) processCallRecover(err *error2.LErrorDesc) {
   243  	e := recover()
   244  	if e == nil {
   245  		return
   246  	}
   247  	var printStr string
   248  	switch e.(type) {
   249  	case error2.LErrorDesc:
   250  		*err = e.(error2.LErrorDesc)
   251  		printStr = (*err).Error()
   252  	case error:
   253  		iErr := e.(error)
   254  		*err = s.eHandle.LNewErrorDesc(error2.Unknown, iErr.Error())
   255  		printStr = iErr.Error()
   256  	case string:
   257  		*err = s.eHandle.LNewErrorDesc(error2.Unknown, e.(string))
   258  		printStr = e.(string)
   259  	default:
   260  		printStr = fmt.Sprintf("%v", e)
   261  		*err = s.eHandle.LNewErrorDesc(error2.Unknown, printStr)
   262  	}
   263  	var stack [4096]byte
   264  	size := runtime.Stack(stack[:], false)
   265  	s.logger.Warn("callee panic : %s\n%s", printStr, convert.BytesToString(stack[:size]))
   266  	return
   267  }
   268  
   269  // 将用户过程的返回结果集序列化为可传输的json数据
   270  func (s *Server) handleResult(msgOpt *messageOpt, msg *message2.Message, callResult []reflect.Value) error2.LErrorDesc {
   271  	for _, v := range callResult[:len(callResult)-1] {
   272  		// NOTE : 对于指针类型或者隐含指针的类型, 他检查用户过程是否返回nil
   273  		// NOTE : 对于非指针的值传递类型, 它检查该类型是否是零值
   274  		// 借助这个哨兵条件可以减少零值的序列化/网络开销
   275  		if v.IsZero() {
   276  			// 添加返回参数的标记, 这是因为在多个返回参数可能出现以下的情况
   277  			// (Value),(Value2),(nil),(Zero)
   278  			// 在以上情况下简单地忽略并不是一个好主意(会导致返回值反序列化异常), 所以需要一个标记让客户端知道
   279  			msg.AppendPayloads(make([]byte, 0))
   280  			continue
   281  		}
   282  		var eface = v.Interface()
   283  		// 可替换的Codec已经不需要Any包装器了
   284  		bytes, err := msgOpt.Codec.Marshal(eface)
   285  		if err != nil {
   286  			return s.eHandle.LWarpErrorDesc(errorhandler.ErrCodecMarshalError, err.Error())
   287  		}
   288  		msg.AppendPayloads(bytes)
   289  	}
   290  	return nil
   291  }
   292  
   293  // 必须在其结果集中首先处理错误在处理其余结果
   294  func (s *Server) setErrResult(msg *message2.Message, errResult reflect.Value) error2.LErrorDesc {
   295  	val := reflect2.ToValueTypeEface(errResult)
   296  	interErr, _ := val.(error)
   297  	return s.setErr(msg, interErr)
   298  }
   299  
   300  func (s *Server) setErr(msg *message2.Message, interErr error) error2.LErrorDesc {
   301  	// 无错误
   302  	if interErr == error(nil) {
   303  		msg.MetaData.Store(message2.ErrorCode, strconv.Itoa(errorhandler.Success.Code()))
   304  		msg.MetaData.Store(message2.ErrorMessage, errorhandler.Success.Message())
   305  		return nil
   306  	}
   307  	// 检查是否实现了自定义错误的接口
   308  	desc, ok := interErr.(error2.LErrorDesc)
   309  	if ok {
   310  		msg.MetaData.Store(message2.ErrorCode, strconv.Itoa(desc.Code()))
   311  		msg.MetaData.Store(message2.ErrorMessage, desc.Message())
   312  		bytes, err := desc.MarshalMores()
   313  		if err != nil {
   314  			return s.eHandle.LWarpErrorDesc(
   315  				errorhandler.ErrCodecMarshalError,
   316  				fmt.Sprintf("%s : %s", message2.ErrorMore, err.Error()))
   317  		}
   318  		msg.MetaData.Store(message2.ErrorMore, convert.BytesToString(bytes))
   319  		return nil
   320  	}
   321  	err, ok := interErr.(error)
   322  	// NOTE 按理来说, 在正常情况下!ok这个分支不应该被激活, 检查每个过程返回error是Elem的责任
   323  	// NOTE 建立这个分支是防止用户自作聪明使用一些Hack的手段绕过了Elem的检查
   324  	if !ok {
   325  		return s.eHandle.LNewErrorDesc(error2.UnsafeOption, "Server.RegisterClass no checker on error")
   326  	}
   327  	msg.MetaData.Store(message2.ErrorCode, strconv.Itoa(error2.Unknown))
   328  	msg.MetaData.Store(message2.ErrorMessage, err.Error())
   329  	return nil
   330  }