github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/server/message_handle.go (about) 1 package server 2 3 import ( 4 "context" 5 "fmt" 6 "github.com/nyan233/littlerpc/core/common/errorhandler" 7 "github.com/nyan233/littlerpc/core/common/inters" 8 "github.com/nyan233/littlerpc/core/common/metadata" 9 metaDataUtil "github.com/nyan233/littlerpc/core/common/utils/metadata" 10 error2 "github.com/nyan233/littlerpc/core/protocol/error" 11 message2 "github.com/nyan233/littlerpc/core/protocol/message" 12 "github.com/nyan233/littlerpc/core/utils/convert" 13 reflect2 "github.com/nyan233/littlerpc/internal/reflect" 14 "reflect" 15 "runtime" 16 "strconv" 17 "time" 18 ) 19 20 var ( 21 hijackResultCache = []reflect.Value{reflect.ValueOf(errorhandler.Success)} 22 ) 23 24 // 过程中的副作用会导致msgOpt.Message在调用结束之前被放回pasrser中 25 func (s *Server) messageKeepAlive(msgOpt *messageOpt) { 26 defer func() { 27 msgOpt.Free() 28 msgOpt.FreePluginCtx() 29 }() 30 if err := msgOpt.RealPayload(); err != nil { 31 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgOpt.Message.GetMsgId(), s.eHandle.LWarpErrorDesc( 32 err, "keep-alive get real payload failed")) 33 return 34 } 35 msgOpt.Message.SetMsgType(message2.Pong) 36 cfg := s.config.Load() 37 if cfg.KeepAlive { 38 err := msgOpt.Conn.SetDeadline(time.Now().Add(cfg.KeepAliveTimeout)) 39 if err != nil { 40 s.logger.Error("LRPC: connection set deadline failed: %v", err) 41 _ = msgOpt.Conn.Close() 42 return 43 } 44 } 45 s.encodeAndSendMsg(msgOpt, msgOpt.Message, nil, false) 46 } 47 48 // 过程中的副作用会导致msgOpt.Message在调用结束之前被放回pasrser中 49 func (s *Server) messageContextCancel(msgOpt *messageOpt) { 50 defer func() { 51 msgOpt.Free() 52 msgOpt.FreePluginCtx() 53 }() 54 if err := msgOpt.RealPayload(); err != nil { 55 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgOpt.Message.GetMsgId(), s.eHandle.LWarpErrorDesc( 56 err, "context-cancel get real payload failed")) 57 return 58 } 59 ctxIdStr, ok := msgOpt.Message.MetaData.LoadOk(message2.ContextId) 60 if !ok { 61 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgOpt.Message.GetMsgId(), error2.LWarpStdError( 62 errorhandler.ContextNotFound, fmt.Sprintf("contextId : %s", ctxIdStr))) 63 } 64 ctxId, err := strconv.ParseUint(ctxIdStr, 10, 64) 65 if err != nil { 66 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgOpt.Message.GetMsgId(), error2.LWarpStdError( 67 errorhandler.ErrServer, err.Error())) 68 } 69 err = msgOpt.Desc.ctxManager.CancelContext(ctxId) 70 if err != nil { 71 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgOpt.Message.GetMsgId(), error2.LWarpStdError( 72 errorhandler.ErrServer, err.Error())) 73 return 74 } 75 s.encodeAndSendMsg(msgOpt, msgOpt.Message, nil, false) 76 } 77 78 // 过程中的副作用会导致msgOpt.Message在调用结束之前被放回pasrser中 79 func (s *Server) messageCall(msgOpt *messageOpt, desc *connSourceDesc) { 80 msgId := msgOpt.Message.GetMsgId() 81 var err error 82 defer func() { 83 if err != nil { 84 msgOpt.Free() 85 msgOpt.FreePluginCtx() 86 } 87 }() 88 err = msgOpt.RealPayload() 89 if err != nil { 90 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err.(error2.LErrorDesc)) 91 return 92 } 93 err = msgOpt.checkService() 94 if err != nil { 95 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err.(error2.LErrorDesc)) 96 return 97 } 98 callHandler := s.callHandleUnit 99 if msgOpt.Hijack() { 100 callHandler = s.hijackCall 101 } else { 102 err = msgOpt.Check() 103 if err != nil { 104 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err.(error2.LErrorDesc)) 105 return 106 } 107 } 108 switch { 109 case msgOpt.Service.Option.SyncCall: 110 callHandler(msgOpt) 111 case msgOpt.Service.Option.UseRawGoroutine: 112 go func() { 113 callHandler(msgOpt) 114 }() 115 default: 116 err = s.taskPool.Push(msgOpt.Message.GetServiceName(), func() { 117 callHandler(msgOpt) 118 }) 119 if err != nil { 120 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, s.eHandle.LWarpErrorDesc(errorhandler.ErrServer, err.Error())) 121 } 122 } 123 } 124 125 // 提供用于任务池的处理调用用户过程的单元 126 // 因为用户过程可能会有阻塞操作 127 func (s *Server) callHandleUnit(msgOpt *messageOpt) { 128 msgId := msgOpt.Message.GetMsgId() 129 msgOpt.Free() 130 131 messageBuffer := s.pool.TakeMessagePool() 132 msg := messageBuffer.Get().(*message2.Message) 133 msg.Reset() 134 defer func() { 135 message2.ResetMsg(msg, false, true, true, 1024) 136 messageBuffer.Put(msg) 137 msgOpt.FreePluginCtx() 138 }() 139 callResult, cErr := s.handleCall(msgOpt.Service, msgOpt.CallArgs) 140 // context存在时且未被取消, 则在调用结束之后取消 141 if msgOpt.Service.SupportContext && msgOpt.CallArgs[0].Interface().(context.Context).Err() == nil && msgOpt.Cancel != nil { 142 msgOpt.Cancel() 143 } 144 145 if cErr == nil && len(callResult) == 0 { 146 // TODO v0.4.x计划删除 147 // 函数在没有返回error则填充nil 148 callResult = append(callResult, reflect.ValueOf(nil)) 149 } 150 err := s.pManager.AfterCall4S(msgOpt.PCtx, msgOpt.CallArgs, callResult, cErr) 151 // AfterCall4S()之后不会再被使用, 可以回收参数 152 if msgOpt.Service.Option.CompleteReUsage { 153 for i := metaDataUtil.InputOffset(msgOpt.Service); i < len(msgOpt.CallArgs); i++ { 154 msgOpt.CallArgs[i].Interface().(inters.Reset).Reset() 155 } 156 msgOpt.Service.Pool.Put(msgOpt.CallArgs) 157 // 置空, 防止放回池中时被其它goroutine重新引用而导致数据竞争, 导致难以排查 158 msgOpt.CallArgs = nil 159 } 160 if err != nil { 161 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err) 162 return 163 } 164 if cErr != nil { 165 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, cErr) 166 return 167 } 168 s.reply(msgOpt, msg, msgId, callResult) 169 } 170 171 func (s *Server) hijackCall(msgOpt *messageOpt) { 172 defer msgOpt.Free() 173 msgId := msgOpt.Message.GetMsgId() 174 ctx, err := msgOpt.getContext() 175 if err != nil { 176 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err) 177 return 178 } 179 localPool := s.pool.TakeMessagePool() 180 replyMsg := localPool.Get().(*message2.Message) 181 replyMsg.Reset() 182 stub := &Stub{ 183 opt: msgOpt, 184 reply: replyMsg, 185 Context: ctx, 186 } 187 defer localPool.Put(replyMsg) 188 err = s.handleCallOnHijack(msgOpt.Service, stub) 189 if err != nil { 190 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err) 191 return 192 } 193 if msgOpt.Service.SupportContext && stub.Context.Err() == nil && msgOpt.Cancel != nil { 194 msgOpt.Cancel() 195 } 196 var hijackResults []reflect.Value 197 if stub.callErr == nil { 198 hijackResults = hijackResultCache 199 } else { 200 hijackResults = []reflect.Value{reflect.ValueOf(stub.callErr)} 201 } 202 err = s.pManager.AfterCall4S(msgOpt.PCtx, nil, hijackResults, nil) 203 if err != nil { 204 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msgId, err) 205 return 206 } 207 s.replyOnHijack(msgOpt, replyMsg, msgId, stub.callErr) 208 } 209 210 func (s *Server) replyOnHijack(msgOpt *messageOpt, msg *message2.Message, msgId uint64, callErr error) { 211 msgOpt.initReplyMsg(msg, msgId) 212 err := s.setErr(msg, callErr) 213 s.encodeAndSendMsg(msgOpt, msg, err, true) 214 } 215 216 func (s *Server) reply(msgOpt *messageOpt, msg *message2.Message, msgId uint64, results []reflect.Value) { 217 msgOpt.initReplyMsg(msg, msgId) 218 // 处理用户过程返回的错误,v0.30开始规定每个符合规范的API最后一个返回值是error接口 219 lErr := s.setErrResult(msg, results[len(results)-1]) 220 if lErr != nil { 221 s.handleError(msgOpt.Conn, msgOpt.Desc.Writer, msg.GetMsgId(), lErr) 222 return 223 } 224 err := s.handleResult(msgOpt, msg, results) 225 s.encodeAndSendMsg(msgOpt, msg, err, true) 226 } 227 228 func (s *Server) handleCall(service *metadata.Process, args []reflect.Value) (results []reflect.Value, err error2.LErrorDesc) { 229 defer s.processCallRecover(&err) 230 results = service.Value.Call(args) 231 return 232 } 233 234 func (s *Server) handleCallOnHijack(service *metadata.Process, stub *Stub) (err error2.LErrorDesc) { 235 defer s.processCallRecover(&err) 236 fun := *(*func(stub *Stub))(service.Hijacker) 237 stub.setup() 238 fun(stub) 239 return nil 240 } 241 242 func (s *Server) processCallRecover(err *error2.LErrorDesc) { 243 e := recover() 244 if e == nil { 245 return 246 } 247 var printStr string 248 switch e.(type) { 249 case error2.LErrorDesc: 250 *err = e.(error2.LErrorDesc) 251 printStr = (*err).Error() 252 case error: 253 iErr := e.(error) 254 *err = s.eHandle.LNewErrorDesc(error2.Unknown, iErr.Error()) 255 printStr = iErr.Error() 256 case string: 257 *err = s.eHandle.LNewErrorDesc(error2.Unknown, e.(string)) 258 printStr = e.(string) 259 default: 260 printStr = fmt.Sprintf("%v", e) 261 *err = s.eHandle.LNewErrorDesc(error2.Unknown, printStr) 262 } 263 var stack [4096]byte 264 size := runtime.Stack(stack[:], false) 265 s.logger.Warn("callee panic : %s\n%s", printStr, convert.BytesToString(stack[:size])) 266 return 267 } 268 269 // 将用户过程的返回结果集序列化为可传输的json数据 270 func (s *Server) handleResult(msgOpt *messageOpt, msg *message2.Message, callResult []reflect.Value) error2.LErrorDesc { 271 for _, v := range callResult[:len(callResult)-1] { 272 // NOTE : 对于指针类型或者隐含指针的类型, 他检查用户过程是否返回nil 273 // NOTE : 对于非指针的值传递类型, 它检查该类型是否是零值 274 // 借助这个哨兵条件可以减少零值的序列化/网络开销 275 if v.IsZero() { 276 // 添加返回参数的标记, 这是因为在多个返回参数可能出现以下的情况 277 // (Value),(Value2),(nil),(Zero) 278 // 在以上情况下简单地忽略并不是一个好主意(会导致返回值反序列化异常), 所以需要一个标记让客户端知道 279 msg.AppendPayloads(make([]byte, 0)) 280 continue 281 } 282 var eface = v.Interface() 283 // 可替换的Codec已经不需要Any包装器了 284 bytes, err := msgOpt.Codec.Marshal(eface) 285 if err != nil { 286 return s.eHandle.LWarpErrorDesc(errorhandler.ErrCodecMarshalError, err.Error()) 287 } 288 msg.AppendPayloads(bytes) 289 } 290 return nil 291 } 292 293 // 必须在其结果集中首先处理错误在处理其余结果 294 func (s *Server) setErrResult(msg *message2.Message, errResult reflect.Value) error2.LErrorDesc { 295 val := reflect2.ToValueTypeEface(errResult) 296 interErr, _ := val.(error) 297 return s.setErr(msg, interErr) 298 } 299 300 func (s *Server) setErr(msg *message2.Message, interErr error) error2.LErrorDesc { 301 // 无错误 302 if interErr == error(nil) { 303 msg.MetaData.Store(message2.ErrorCode, strconv.Itoa(errorhandler.Success.Code())) 304 msg.MetaData.Store(message2.ErrorMessage, errorhandler.Success.Message()) 305 return nil 306 } 307 // 检查是否实现了自定义错误的接口 308 desc, ok := interErr.(error2.LErrorDesc) 309 if ok { 310 msg.MetaData.Store(message2.ErrorCode, strconv.Itoa(desc.Code())) 311 msg.MetaData.Store(message2.ErrorMessage, desc.Message()) 312 bytes, err := desc.MarshalMores() 313 if err != nil { 314 return s.eHandle.LWarpErrorDesc( 315 errorhandler.ErrCodecMarshalError, 316 fmt.Sprintf("%s : %s", message2.ErrorMore, err.Error())) 317 } 318 msg.MetaData.Store(message2.ErrorMore, convert.BytesToString(bytes)) 319 return nil 320 } 321 err, ok := interErr.(error) 322 // NOTE 按理来说, 在正常情况下!ok这个分支不应该被激活, 检查每个过程返回error是Elem的责任 323 // NOTE 建立这个分支是防止用户自作聪明使用一些Hack的手段绕过了Elem的检查 324 if !ok { 325 return s.eHandle.LNewErrorDesc(error2.UnsafeOption, "Server.RegisterClass no checker on error") 326 } 327 msg.MetaData.Store(message2.ErrorCode, strconv.Itoa(error2.Unknown)) 328 msg.MetaData.Store(message2.ErrorMessage, err.Error()) 329 return nil 330 }