github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/server/msg_opt.go (about)

     1  package server
     2  
     3  import (
     4  	"context"
     5  	"fmt"
     6  	"github.com/nyan233/littlerpc/core/common/check"
     7  	rContext "github.com/nyan233/littlerpc/core/common/context"
     8  	"github.com/nyan233/littlerpc/core/common/errorhandler"
     9  	"github.com/nyan233/littlerpc/core/common/metadata"
    10  	"github.com/nyan233/littlerpc/core/common/msgparser"
    11  	"github.com/nyan233/littlerpc/core/common/stream"
    12  	"github.com/nyan233/littlerpc/core/common/transport"
    13  	metaDataUtil "github.com/nyan233/littlerpc/core/common/utils/metadata"
    14  	"github.com/nyan233/littlerpc/core/container"
    15  	"github.com/nyan233/littlerpc/core/middle/codec"
    16  	"github.com/nyan233/littlerpc/core/middle/packer"
    17  	"github.com/nyan233/littlerpc/core/middle/plugin"
    18  	perror "github.com/nyan233/littlerpc/core/protocol/error"
    19  	"github.com/nyan233/littlerpc/core/protocol/message"
    20  	"github.com/nyan233/littlerpc/core/protocol/message/mux"
    21  	reflect2 "github.com/nyan233/littlerpc/internal/reflect"
    22  	"reflect"
    23  	"strconv"
    24  	"time"
    25  )
    26  
// messageOpt holds the per-request state for one incoming message.
// Many of its methods have side effects (they mutate its fields, register
// contexts with the connection's context manager, or release pooled
// resources), so use them with care.
type messageOpt struct {
	Server   *Server
	Header   byte          // header byte of the parsed message (set in newConnDesc)
	Codec    codec.Codec   // chosen by SelectCodecAndEncoder
	Packer   packer.Packer // chosen by SelectCodecAndEncoder
	Message  *message.Message
	freeFunc func(msg *message.Message) // installed via setFreeFunc, invoked by Free
	Service  *metadata.Process          // resolved lazily by checkService
	Conn     transport.ConnAdapter
	Desc     *connSourceDesc
	// The original Context-Id is no longer used directly: a new context is
	// created for every request. Cancel cancels the context derived from the
	// raw context registered for the context-id, so it carries no context-id
	// of its own.
	Cancel   context.CancelFunc
	CallArgs []reflect.Value // set by Check on success
	PCtx     *plugin.Context // nil when the server has no plugins (see newConnDesc)
}
    44  
    45  func newConnDesc(s *Server, msg msgparser.ParserMessage, conn transport.ConnAdapter, desc *connSourceDesc) *messageOpt {
    46  	opt := &messageOpt{
    47  		Server:  s,
    48  		Message: msg.Message,
    49  		Header:  msg.Header,
    50  		Conn:    conn,
    51  		Desc:    desc,
    52  	}
    53  	if opt.Server.pManager.Size() <= 0 {
    54  		opt.PCtx = nil
    55  	} else {
    56  		opt.PCtx = s.pManager.GetContext()
    57  		opt.PCtx.PluginContext = injectPluginContext(desc.cacheCtx, msg.Message.GetMsgType(), msg.Message.GetServiceName(), time.Now())
    58  		opt.PCtx.Logger = s.logger
    59  		opt.PCtx.EHandler = s.eHandle
    60  	}
    61  	return opt
    62  }
    63  
    64  func (c *messageOpt) SelectCodecAndEncoder() {
    65  	// 根据读取的头信息初始化一些需要的Codec/Packer
    66  	c.Codec = codec.Get(c.Message.MetaData.Load(message.CodecScheme))
    67  	c.Packer = packer.Get(c.Message.MetaData.Load(message.PackerScheme))
    68  	if c.Codec == nil {
    69  		c.Codec = codec.Get(message.DefaultCodec)
    70  	}
    71  	if c.Packer == nil {
    72  		c.Packer = packer.Get(message.DefaultPacker)
    73  	}
    74  }
    75  
    76  // RealPayload 获取真正的Payload, 如果有压缩则解压
    77  func (c *messageOpt) RealPayload() perror.LErrorDesc {
    78  	if c.Packer.Scheme() != "text" {
    79  		bytes, err := c.Packer.UnPacket(c.Message.Payloads())
    80  		if err != nil {
    81  			return c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrServer, err.Error())
    82  		}
    83  		c.Message.SetPayloads(bytes)
    84  	}
    85  	if err := c.Server.pManager.Receive4S(c.PCtx, c.Message); err != nil {
    86  		return err
    87  	}
    88  	return nil
    89  }
    90  
    91  // Free 不允许释放nil message, 或者重复释放, 否则panic
    92  func (c *messageOpt) Free() {
    93  	if c.Message == nil {
    94  		panic("release not found message or retry release")
    95  	}
    96  	c.freeFunc(c.Message)
    97  	c.Message = nil
    98  }
    99  
   100  func (c *messageOpt) FreePluginCtx() {
   101  	if c.PCtx == nil {
   102  		return
   103  	}
   104  	ctx := c.PCtx
   105  	c.PCtx = nil
   106  	c.Server.pManager.FreeContext(ctx)
   107  }
   108  
   109  func (c *messageOpt) setFreeFunc(f func(msg *message.Message)) {
   110  	c.freeFunc = f
   111  }
   112  
   113  func (c *messageOpt) Hijack() bool {
   114  	return c.Service.Hijack
   115  }
   116  
   117  // UseMux TODO: 计划删除, 这样做并不能判断是否使用了Mux
   118  func (c *messageOpt) UseMux() bool {
   119  	return c.Message.First() == mux.Enabled
   120  }
   121  
   122  func (c *messageOpt) Check() perror.LErrorDesc {
   123  	err := c.checkService()
   124  	if err != nil {
   125  		return err
   126  	}
   127  	// 从客户端校验并获得合法的调用参数
   128  	callArgs, lErr := c.checkCallArgs()
   129  	if err := c.Server.pManager.Call4S(c.PCtx, callArgs, lErr); err != nil {
   130  		return c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrPlugin, err)
   131  	}
   132  	if lErr != nil {
   133  		return c.Server.eHandle.LWarpErrorDesc(lErr, "arguments check failed")
   134  	}
   135  	c.CallArgs = callArgs
   136  	return nil
   137  }
   138  
   139  func (c *messageOpt) checkService() perror.LErrorDesc {
   140  	if c.Service != nil {
   141  		return nil
   142  	}
   143  	// 序列化完之后才确定调用名
   144  	// MethodName : Hello.Hello : receiver:methodName
   145  	service, ok := c.Server.services.LoadOk(c.Message.GetServiceName())
   146  	if !ok {
   147  		return c.Server.eHandle.LWarpErrorDesc(
   148  			errorhandler.ServiceNotfound, c.Message.GetServiceName())
   149  	}
   150  	c.Service = service
   151  	return nil
   152  }
   153  
   154  func (c *messageOpt) checkCallArgs() (values []reflect.Value, err perror.LErrorDesc) {
   155  	// 去除接收者之后的输入参数长度
   156  	// 校验客户端传递的参数和服务端是否一致
   157  	iter := c.Message.PayloadsIterator()
   158  	if nInput := len(c.Service.ArgsType) - metaDataUtil.InputOffset(c.Service); nInput != iter.Tail() {
   159  		return nil, c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrServer,
   160  			"client input args number no equal server",
   161  			fmt.Sprintf("Client : %d", iter.Tail()), fmt.Sprintf("Server : %d", nInput))
   162  	}
   163  	// 哨兵条件, 过程不要求任何输入时即可以提前结束
   164  	if len(c.Service.ArgsType) == 0 {
   165  		return
   166  	}
   167  	defer func() {
   168  		if err == nil {
   169  			return
   170  		}
   171  		if c.Cancel != nil {
   172  			c.Cancel()
   173  		}
   174  	}()
   175  	var callArgs []reflect.Value
   176  	var inputStart int
   177  	if c.Service.Option.CompleteReUsage {
   178  		callArgs = c.Service.Pool.Get().([]reflect.Value)
   179  		defer func() {
   180  			if err != nil {
   181  				c.Service.Pool.Put(&callArgs)
   182  			}
   183  		}()
   184  		inputStart, err = c.checkContextAndStream(callArgs, true)
   185  		if err != nil {
   186  			return
   187  		}
   188  	} else {
   189  		callArgs = reflect2.FuncInputTypeListReturnValue(c.Service.ArgsType, 0, func(i int) bool {
   190  			if len(iter.Take()) == 0 {
   191  				return true
   192  			}
   193  			return false
   194  		}, true)
   195  		inputStart, err = c.checkContextAndStream(callArgs, true)
   196  		if err != nil {
   197  			return
   198  		}
   199  	}
   200  	iter.Reset()
   201  	for i := inputStart; i < len(callArgs) && iter.Next(); i++ {
   202  		callArg, err := check.UnMarshalFromUnsafe(c.Codec, iter.Take(), callArgs[i].Interface())
   203  		if err != nil {
   204  			return nil, c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrCodecMarshalError, err.Error())
   205  		}
   206  		// 可以根据获取的参数类别的每一个参数的类型信息得到
   207  		// 所需的精确类型,所以不用再对变长的类型做处理
   208  		callArgs[i] = reflect.ValueOf(callArg)
   209  	}
   210  	return callArgs, nil
   211  }
   212  
   213  func (c *messageOpt) getContext() (context.Context, perror.LErrorDesc) {
   214  	ctx := context.Background()
   215  	ctxIdStr, ok := c.Message.MetaData.LoadOk(message.ContextId)
   216  	// 客户端携带context-id且对应的过程支持context时才注册context
   217  	// 为不支持context的过程注册时无意义的且可能会导致context泄漏
   218  	if ok && c.Service.SupportContext {
   219  		ctxId, err := strconv.ParseUint(ctxIdStr, 10, 64)
   220  		if err != nil {
   221  			return nil, c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrServer, err.Error())
   222  		}
   223  		rawCtx, _ := c.Desc.ctxManager.RegisterContextCancel(ctxId)
   224  		ctx, c.Cancel = context.WithCancel(rawCtx)
   225  		if err != nil {
   226  			return nil, c.Server.eHandle.LWarpErrorDesc(errorhandler.ErrServer, err.Error())
   227  		}
   228  	}
   229  	ctx = rContext.WithLocalAddr(ctx, c.Desc.localAddr)
   230  	ctx = rContext.WithRemoteAddr(ctx, c.Desc.remoteAddr)
   231  	return ctx, nil
   232  }
   233  
   234  func (c *messageOpt) checkContextAndStream(callArgs container.Slice[reflect.Value], write bool) (offset int, err perror.LErrorDesc) {
   235  	ctx, err := c.getContext()
   236  	if err != nil {
   237  		return 0, err
   238  	}
   239  	callArgs.Reset()
   240  	switch {
   241  	case c.Service.SupportContext:
   242  		offset = 1
   243  		if write {
   244  			callArgs.AppendSingle(reflect.ValueOf(ctx))
   245  		}
   246  	case c.Service.SupportContext && c.Service.SupportStream:
   247  		offset = 2
   248  		if write {
   249  			callArgs.AppendS(reflect.ValueOf(ctx), reflect.ValueOf(*new(stream.LStream)))
   250  		}
   251  	case c.Service.SupportStream:
   252  		offset = 1
   253  		if write {
   254  			callArgs.AppendSingle(reflect.ValueOf(*new(stream.LStream)))
   255  		}
   256  	default:
   257  		// 不支持context&stream
   258  		break
   259  	}
   260  	return
   261  }
   262  
   263  func (c *messageOpt) initReplyMsg(msg *message.Message, msgId uint64) {
   264  	msg.SetMsgType(message.Return)
   265  	msg.SetMsgId(msgId)
   266  	if c.Codec.Scheme() != message.DefaultCodec {
   267  		msg.MetaData.Store(message.CodecScheme, c.Codec.Scheme())
   268  	}
   269  	if c.Packer.Scheme() != message.DefaultPacker {
   270  		msg.MetaData.Store(message.PackerScheme, c.Packer.Scheme())
   271  	}
   272  }
   273  
   274  func injectPluginContext(ctx context.Context, msgType uint8, service string, start time.Time) context.Context {
   275  	ctx = rContext.WithInitData(ctx, &rContext.InitData{
   276  		Start:       start,
   277  		ServiceName: service,
   278  		MsgType:     msgType,
   279  	})
   280  	return ctx
   281  }