github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/server/msg_opt.go (about) 1 package server 2 3 import ( 4 "context" 5 "fmt" 6 "github.com/nyan233/littlerpc/core/common/check" 7 rContext "github.com/nyan233/littlerpc/core/common/context" 8 "github.com/nyan233/littlerpc/core/common/errorhandler" 9 "github.com/nyan233/littlerpc/core/common/metadata" 10 "github.com/nyan233/littlerpc/core/common/msgparser" 11 "github.com/nyan233/littlerpc/core/common/stream" 12 "github.com/nyan233/littlerpc/core/common/transport" 13 metaDataUtil "github.com/nyan233/littlerpc/core/common/utils/metadata" 14 "github.com/nyan233/littlerpc/core/container" 15 "github.com/nyan233/littlerpc/core/middle/codec" 16 "github.com/nyan233/littlerpc/core/middle/packer" 17 "github.com/nyan233/littlerpc/core/middle/plugin" 18 perror "github.com/nyan233/littlerpc/core/protocol/error" 19 "github.com/nyan233/littlerpc/core/protocol/message" 20 "github.com/nyan233/littlerpc/core/protocol/message/mux" 21 reflect2 "github.com/nyan233/littlerpc/internal/reflect" 22 "reflect" 23 "strconv" 24 "time" 25 ) 26 27 // 该类型拥有的方法都有很多的副作用, 请谨慎 28 type messageOpt struct { 29 Server *Server 30 Header byte 31 Codec codec.Codec 32 Packer packer.Packer 33 Message *message.Message 34 freeFunc func(msg *message.Message) 35 Service *metadata.Process 36 Conn transport.ConnAdapter 37 Desc *connSourceDesc 38 // 弃用原来的Context-Id, Context-Id时会为每次请求创建一个新的context 39 // Cancel func取消的是从context-id创建的原始context中派生的, 因此并没有context-id 40 Cancel context.CancelFunc 41 CallArgs []reflect.Value 42 PCtx *plugin.Context 43 } 44 45 func newConnDesc(s *Server, msg msgparser.ParserMessage, conn transport.ConnAdapter, desc *connSourceDesc) *messageOpt { 46 opt := &messageOpt{ 47 Server: s, 48 Message: msg.Message, 49 Header: msg.Header, 50 Conn: conn, 51 Desc: desc, 52 } 53 if opt.Server.pManager.Size() <= 0 { 54 opt.PCtx = nil 55 } else { 56 opt.PCtx = s.pManager.GetContext() 57 opt.PCtx.PluginContext = injectPluginContext(desc.cacheCtx, msg.Message.GetMsgType(), msg.Message.GetServiceName(), time.Now()) 58 opt.PCtx.Logger = s.logger 59 opt.PCtx.EHandler = s.eHandle 60 } 61 return opt 62 } 63 64 func (c *messageOpt) SelectCodecAndEncoder() { 65 // 根据读取的头信息初始化一些需要的Codec/Packer 66 c.Codec = codec.Get(c.Message.MetaData.Load(message.CodecScheme)) 67 c.Packer = packer.Get(c.Message.MetaData.Load(message.PackerScheme)) 68 if c.Codec == nil { 69 c.Codec = codec.Get(message.DefaultCodec) 70 } 71 if c.Packer == nil { 72 c.Packer = packer.Get(message.DefaultPacker) 73 } 74 } 75 76 // RealPayload 获取真正的Payload, 如果有压缩则解压 77 func (c *messageOpt) RealPayload() perror.LErrorDesc { 78 if c.Packer.Scheme() != "text" { 79 bytes, err := c.Packer.UnPacket(c.Message.Payloads()) 80 if err != nil { 81 return c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrServer, err.Error()) 82 } 83 c.Message.SetPayloads(bytes) 84 } 85 if err := c.Server.pManager.Receive4S(c.PCtx, c.Message); err != nil { 86 return err 87 } 88 return nil 89 } 90 91 // Free 不允许释放nil message, 或者重复释放, 否则panic 92 func (c *messageOpt) Free() { 93 if c.Message == nil { 94 panic("release not found message or retry release") 95 } 96 c.freeFunc(c.Message) 97 c.Message = nil 98 } 99 100 func (c *messageOpt) FreePluginCtx() { 101 if c.PCtx == nil { 102 return 103 } 104 ctx := c.PCtx 105 c.PCtx = nil 106 c.Server.pManager.FreeContext(ctx) 107 } 108 109 func (c *messageOpt) setFreeFunc(f func(msg *message.Message)) { 110 c.freeFunc = f 111 } 112 113 func (c *messageOpt) Hijack() bool { 114 return c.Service.Hijack 115 } 116 117 // UseMux TODO: 计划删除, 这样做并不能判断是否使用了Mux 118 func (c *messageOpt) UseMux() bool { 119 return c.Message.First() == mux.Enabled 120 } 121 122 func (c *messageOpt) Check() perror.LErrorDesc { 123 err := c.checkService() 124 if err != nil { 125 return err 126 } 127 // 从客户端校验并获得合法的调用参数 128 callArgs, lErr := c.checkCallArgs() 129 if err := c.Server.pManager.Call4S(c.PCtx, callArgs, lErr); err != nil { 130 return c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrPlugin, err) 131 } 132 if lErr != nil { 133 return c.Server.eHandle.LWarpErrorDesc(lErr, "arguments check failed") 134 } 135 c.CallArgs = callArgs 136 return nil 137 } 138 139 func (c *messageOpt) checkService() perror.LErrorDesc { 140 if c.Service != nil { 141 return nil 142 } 143 // 序列化完之后才确定调用名 144 // MethodName : Hello.Hello : receiver:methodName 145 service, ok := c.Server.services.LoadOk(c.Message.GetServiceName()) 146 if !ok { 147 return c.Server.eHandle.LWarpErrorDesc( 148 errorhandler.ServiceNotfound, c.Message.GetServiceName()) 149 } 150 c.Service = service 151 return nil 152 } 153 154 func (c *messageOpt) checkCallArgs() (values []reflect.Value, err perror.LErrorDesc) { 155 // 去除接收者之后的输入参数长度 156 // 校验客户端传递的参数和服务端是否一致 157 iter := c.Message.PayloadsIterator() 158 if nInput := len(c.Service.ArgsType) - metaDataUtil.InputOffset(c.Service); nInput != iter.Tail() { 159 return nil, c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrServer, 160 "client input args number no equal server", 161 fmt.Sprintf("Client : %d", iter.Tail()), fmt.Sprintf("Server : %d", nInput)) 162 } 163 // 哨兵条件, 过程不要求任何输入时即可以提前结束 164 if len(c.Service.ArgsType) == 0 { 165 return 166 } 167 defer func() { 168 if err == nil { 169 return 170 } 171 if c.Cancel != nil { 172 c.Cancel() 173 } 174 }() 175 var callArgs []reflect.Value 176 var inputStart int 177 if c.Service.Option.CompleteReUsage { 178 callArgs = c.Service.Pool.Get().([]reflect.Value) 179 defer func() { 180 if err != nil { 181 c.Service.Pool.Put(&callArgs) 182 } 183 }() 184 inputStart, err = c.checkContextAndStream(callArgs, true) 185 if err != nil { 186 return 187 } 188 } else { 189 callArgs = reflect2.FuncInputTypeListReturnValue(c.Service.ArgsType, 0, func(i int) bool { 190 if len(iter.Take()) == 0 { 191 return true 192 } 193 return false 194 }, true) 195 inputStart, err = c.checkContextAndStream(callArgs, true) 196 if err != nil { 197 return 198 } 199 } 200 iter.Reset() 201 for i := inputStart; i < len(callArgs) && iter.Next(); i++ { 202 callArg, err := check.UnMarshalFromUnsafe(c.Codec, iter.Take(), callArgs[i].Interface()) 203 if err != nil { 204 return nil, c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrCodecMarshalError, err.Error()) 205 } 206 // 可以根据获取的参数类别的每一个参数的类型信息得到 207 // 所需的精确类型,所以不用再对变长的类型做处理 208 callArgs[i] = reflect.ValueOf(callArg) 209 } 210 return callArgs, nil 211 } 212 213 func (c *messageOpt) getContext() (context.Context, perror.LErrorDesc) { 214 ctx := context.Background() 215 ctxIdStr, ok := c.Message.MetaData.LoadOk(message.ContextId) 216 // 客户端携带context-id且对应的过程支持context时才注册context 217 // 为不支持context的过程注册时无意义的且可能会导致context泄漏 218 if ok && c.Service.SupportContext { 219 ctxId, err := strconv.ParseUint(ctxIdStr, 10, 64) 220 if err != nil { 221 return nil, c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrServer, err.Error()) 222 } 223 rawCtx, _ := c.Desc.ctxManager.RegisterContextCancel(ctxId) 224 ctx, c.Cancel = context.WithCancel(rawCtx) 225 if err != nil { 226 return nil, c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrServer, err.Error()) 227 } 228 } 229 ctx = rContext.WithLocalAddr(ctx, c.Desc.localAddr) 230 ctx = rContext.WithRemoteAddr(ctx, c.Desc.remoteAddr) 231 return ctx, nil 232 } 233 234 func (c *messageOpt) checkContextAndStream(callArgs container.Slice[reflect.Value], write bool) (offset int, err perror.LErrorDesc) { 235 ctx, err := c.getContext() 236 if err != nil { 237 return 0, err 238 } 239 callArgs.Reset() 240 switch { 241 case c.Service.SupportContext: 242 offset = 1 243 if write { 244 callArgs.AppendSingle(reflect.ValueOf(ctx)) 245 } 246 case c.Service.SupportContext && c.Service.SupportStream: 247 offset = 2 248 if write { 249 callArgs.AppendS(reflect.ValueOf(ctx), reflect.ValueOf(*new(stream.LStream))) 250 } 251 case c.Service.SupportStream: 252 offset = 1 253 if write { 254 callArgs.AppendSingle(reflect.ValueOf(*new(stream.LStream))) 255 } 256 default: 257 // 不支持context&stream 258 break 259 } 260 return 261 } 262 263 func (c *messageOpt) initReplyMsg(msg *message.Message, msgId uint64) { 264 msg.SetMsgType(message.Return) 265 msg.SetMsgId(msgId) 266 if c.Codec.Scheme() != message.DefaultCodec { 267 msg.MetaData.Store(message.CodecScheme, c.Codec.Scheme()) 268 } 269 if c.Packer.Scheme() != message.DefaultPacker { 270 msg.MetaData.Store(message.PackerScheme, c.Packer.Scheme()) 271 } 272 } 273 274 func injectPluginContext(ctx context.Context, msgType uint8, service string, start time.Time) context.Context { 275 ctx = rContext.WithInitData(ctx, &rContext.InitData{ 276 Start: start, 277 ServiceName: service, 278 MsgType: msgType, 279 }) 280 return ctx 281 }