github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/common/msgparser/lrpc_msgparser.go (about)

     1  package msgparser
     2  
     3  import (
     4  	"errors"
     5  	"fmt"
     6  	"github.com/nyan233/littlerpc/core/container"
     7  	message2 "github.com/nyan233/littlerpc/core/protocol/message"
     8  	"sync"
     9  )
    10  
    11  type readyBuffer struct {
    12  	MsgId         uint64
    13  	PayloadLength uint32
    14  	RawBytes      []byte
    15  }
    16  
    17  func (b *readyBuffer) IsUnmarshal() bool {
    18  	return b.MsgId != 0 && b.PayloadLength != 0
    19  }
    20  
    21  // lRPCTrait 特征化Parser, 根据Header自动选择适合的Handler
    22  type lRPCTrait struct {
    23  	mu sync.Mutex
    24  	// 简单的分配器接口, 用于分配可复用的Message
    25  	allocTor AllocTor
    26  	// 下一个状态的触发间隔, 也就是距离转移到下一个状态需要读取的数据量
    27  	clickInterval int
    28  	// 当前parser选中的handler
    29  	handler MessageHandler
    30  	// 当前在状态机中处于的状态
    31  	state int
    32  	// 单次解析数据的起始偏移量
    33  	startOffset int
    34  	// 行内指针, 指示在start-end中目前的解析处于哪个位置
    35  	linePtr int
    36  	// 单次解析数据的结束偏移量
    37  	endOffset int
    38  	// 消息的最小解析长度, 包括1Byte的Header
    39  	msgBaseLen int
    40  	// 用于存储未完成的消息, 用于mux
    41  	noReadyBuffer map[uint64]readyBuffer
    42  	// 存储半包数据的缓冲区, 只有在读完了一个完整的payload的消息的数据包
    43  	// 才会被直接提升到noReadyBuffer中, noMux类型的数据包则不会被提升到
    44  	// noReadyBuffer中, 将完整的消息读取完毕后直接触发onComplete
    45  	halfBuffer container.ByteSlice
    46  }
    47  
    48  func NewLRPCTrait(allocTor AllocTor, bufSize uint32) Parser {
    49  	if bufSize > MaxBufferSize {
    50  		bufSize = MaxBufferSize
    51  	} else if bufSize == 0 {
    52  		bufSize = DefaultBufferSize
    53  	}
    54  	return &lRPCTrait{
    55  		allocTor:      allocTor,
    56  		clickInterval: 1,
    57  		state:         _ScanInit,
    58  		noReadyBuffer: make(map[uint64]readyBuffer, 16),
    59  		halfBuffer:    make([]byte, 0, bufSize),
    60  	}
    61  }
    62  
    63  func (h *lRPCTrait) ParseOnReader(reader func([]byte) (n int, err error)) (msgs []ParserMessage, err error) {
    64  	h.mu.Lock()
    65  	defer h.mu.Unlock()
    66  	currentLen := h.halfBuffer.Len()
    67  	currentCap := h.halfBuffer.Cap()
    68  	h.halfBuffer = h.halfBuffer[:currentCap]
    69  	for i := 0; i < 16; i++ {
    70  		readN, err := reader(h.halfBuffer[currentLen:currentCap])
    71  		if readN > 0 {
    72  			currentLen += readN
    73  		}
    74  		// read full
    75  		if currentLen == currentCap {
    76  			break
    77  		}
    78  		if err != nil {
    79  			break
    80  		}
    81  	}
    82  	h.halfBuffer = h.halfBuffer[:currentLen]
    83  	return h.parseFromHalfBuffer()
    84  }
    85  
    86  // Parse io.Reader主要用来标识一个读取到半包的连接, 并不会真正去调用他的方法
    87  func (h *lRPCTrait) Parse(data []byte) (msgs []ParserMessage, err error) {
    88  	h.mu.Lock()
    89  	defer h.mu.Unlock()
    90  	if h.clickInterval == 1 && len(data) == 0 {
    91  		return nil, errors.New("data length == 0")
    92  	}
    93  	h.halfBuffer.Append(data)
    94  	return h.parseFromHalfBuffer()
    95  }
    96  
    97  func (h *lRPCTrait) memSwap() {
    98  	if !(h.startOffset+h.clickInterval > h.halfBuffer.Cap()) {
    99  		return
   100  	}
   101  
   102  }
   103  
   104  func (h *lRPCTrait) parseFromHalfBuffer() (msgs []ParserMessage, err error) {
   105  	allMsg := make([]ParserMessage, 0, 4)
   106  	defer func() {
   107  		if err != nil {
   108  			h.ResetScan()
   109  			if len(allMsg) == 0 {
   110  				return
   111  			}
   112  			for _, msg := range allMsg {
   113  				h.allocTor.FreeMessage(msg.Message)
   114  			}
   115  		}
   116  	}()
   117  	var ableParse bool
   118  	var parseInterrupt bool
   119  	for {
   120  		if parseInterrupt {
   121  			break
   122  		}
   123  		if ableParse && h.clickInterval > h.halfBuffer.Len()-h.startOffset {
   124  			break
   125  		}
   126  		// scan all
   127  		if h.halfBuffer.Len() == h.endOffset {
   128  			h.halfBuffer.Reset()
   129  			h.startOffset = 0
   130  			h.endOffset = 0
   131  			h.linePtr = 0
   132  			h.msgBaseLen = 0
   133  			return allMsg, nil
   134  		}
   135  		switch h.state {
   136  		case _ScanInit:
   137  			err := h.handleScanInit(&allMsg)
   138  			if err != nil {
   139  				return nil, err
   140  			}
   141  		case _ScanMsgParse1:
   142  			next, err := h.handleScanParse1(&allMsg)
   143  			if err != nil {
   144  				return nil, err
   145  			}
   146  			if !next {
   147  				parseInterrupt = true
   148  			} else {
   149  				ableParse = true
   150  			}
   151  		case _ScanMsgParse2:
   152  			next, err := h.handleScanParse2(&allMsg)
   153  			if err != nil {
   154  				return nil, err
   155  			}
   156  			if !next {
   157  				parseInterrupt = true
   158  			} else {
   159  				ableParse = true
   160  			}
   161  		}
   162  	}
   163  	// 最后的数据不满足长度要求则可以搬迁数据, 至少要经过一次完整的解析
   164  	if ableParse && (h.halfBuffer.Len()-h.startOffset < h.clickInterval) && h.startOffset > 0 {
   165  		oldBuffer := h.halfBuffer
   166  		h.halfBuffer = h.halfBuffer[:h.halfBuffer.Len()-h.startOffset]
   167  		copy(h.halfBuffer, oldBuffer[h.startOffset:])
   168  		h.endOffset = h.endOffset - h.startOffset
   169  		h.startOffset = 0
   170  		h.linePtr = h.endOffset
   171  		h.msgBaseLen = 0
   172  	}
   173  	return allMsg, nil
   174  }
   175  
   176  func (h *lRPCTrait) handleScanInit(allMsg *[]ParserMessage) (err error) {
   177  	if handler := GetHandler(h.halfBuffer[h.startOffset]); handler != nil {
   178  		h.handler = handler
   179  	} else {
   180  		return errors.New(fmt.Sprintf("MagicNumber no MessageHandler -> %d", (h.halfBuffer)[0]))
   181  	}
   182  	h.state = _ScanMsgParse1
   183  	opt, baseLen := h.handler.BaseLen()
   184  	if opt == SingleRequest {
   185  		msg := h.allocTor.AllocMessage()
   186  		msg.Reset()
   187  		defer func() {
   188  			if err != nil {
   189  				h.allocTor.FreeMessage(msg)
   190  			}
   191  			h.ResetScan()
   192  		}()
   193  		action, err := h.handler.Unmarshal(h.halfBuffer, msg)
   194  		if err != nil {
   195  			return err
   196  		}
   197  		h.linePtr = h.halfBuffer.Len()
   198  		h.endOffset = h.halfBuffer.Len()
   199  		err = h.handleAction(action, h.halfBuffer, msg, allMsg, nil)
   200  		if err != nil {
   201  			return err
   202  		}
   203  		return nil
   204  	}
   205  	h.msgBaseLen = baseLen
   206  	h.clickInterval = baseLen - 1
   207  	h.endOffset++
   208  	h.linePtr++
   209  	return nil
   210  }
   211  
   212  func (h *lRPCTrait) handleScanParse1(allMsg *[]ParserMessage) (next bool, err error) {
   213  	interval := h.halfBuffer.Len() - h.linePtr
   214  	if interval < 0 {
   215  		return false, errors.New("no read buf")
   216  	}
   217  	if interval < h.clickInterval {
   218  		return false, nil
   219  	}
   220  	interval = h.clickInterval
   221  	h.linePtr += interval
   222  	h.endOffset += interval
   223  	h.clickInterval = 0
   224  	_, baseLen := h.handler.BaseLen()
   225  	h.clickInterval = h.handler.MessageLength(h.halfBuffer[h.startOffset:h.endOffset]) - baseLen
   226  	h.state = _ScanMsgParse2
   227  	next = true
   228  	return
   229  }
   230  
   231  func (h *lRPCTrait) handleScanParse2(allMsg *[]ParserMessage) (next bool, err error) {
   232  	interval := h.halfBuffer.Len() - h.linePtr
   233  	if interval < 0 {
   234  		return false, errors.New("no read buf")
   235  	}
   236  	if interval < h.clickInterval {
   237  		return false, nil
   238  	}
   239  	interval = h.clickInterval
   240  	h.linePtr += interval
   241  	h.endOffset += interval
   242  	h.clickInterval = 0
   243  	msg := h.allocTor.AllocMessage()
   244  	msg.Reset()
   245  	action, err := h.handler.Unmarshal(h.halfBuffer[h.startOffset:h.endOffset], msg)
   246  	if err != nil {
   247  		h.allocTor.FreeMessage(msg)
   248  		return false, err
   249  	}
   250  	err = h.handleAction(action, h.halfBuffer, msg, allMsg, h.halfBuffer[h.startOffset+h.msgBaseLen:h.endOffset])
   251  	if err != nil {
   252  		h.allocTor.FreeMessage(msg)
   253  		return false, err
   254  	}
   255  	h.ResetScan()
   256  	h.startOffset = h.endOffset
   257  	next = true
   258  	return
   259  }
   260  
   261  func (h *lRPCTrait) Free(msg *message2.Message) {
   262  	h.allocTor.FreeMessage(msg)
   263  }
   264  
   265  func (h *lRPCTrait) Reset() {
   266  	h.ResetScan()
   267  }
   268  
   269  func (h *lRPCTrait) ResetScan() {
   270  	h.handler = nil
   271  	h.clickInterval = 1
   272  	h.state = _ScanInit
   273  }
   274  
   275  func (h *lRPCTrait) deleteNoReadyBuffer(msgId uint64) {
   276  	// 置空/删除Map Key让内存得以回收
   277  	h.noReadyBuffer[msgId] = readyBuffer{}
   278  	delete(h.noReadyBuffer, msgId)
   279  }
   280  
   281  // State 下个状态的触发间隔&当前的状态&缓冲区的长度
   282  func (h *lRPCTrait) State() (int, int, int) {
   283  	h.mu.Lock()
   284  	defer h.mu.Unlock()
   285  	return h.clickInterval, h.state, len(h.halfBuffer)
   286  }
   287  
   288  func (h *lRPCTrait) handleAction(action Action, buf container.ByteSlice, msg *message2.Message, allMsg *[]ParserMessage, readData []byte) error {
   289  	switch action {
   290  	case UnmarshalBase:
   291  		readBuf, ok := h.noReadyBuffer[msg.GetMsgId()]
   292  		if !ok {
   293  			readBuf = readyBuffer{
   294  				MsgId:         msg.GetMsgId(),
   295  				PayloadLength: msg.Length(),
   296  				RawBytes:      buf,
   297  			}
   298  		} else {
   299  			readBuf.RawBytes = append(readBuf.RawBytes, readData...)
   300  		}
   301  		if uint32(len(readBuf.RawBytes)) == readBuf.PayloadLength {
   302  			defer h.deleteNoReadyBuffer(msg.GetMsgId())
   303  			msg.Reset()
   304  			err := message2.Unmarshal(readBuf.RawBytes, msg)
   305  			if err != nil {
   306  				return err
   307  			}
   308  			*allMsg = append(*allMsg, ParserMessage{
   309  				Message: msg,
   310  				Header:  readBuf.RawBytes[0],
   311  			})
   312  		} else if !ok {
   313  			readBuf.RawBytes = append([]byte{}, readData...)
   314  			// mux中的消息不能一次性序列化完成则释放预分配的msg
   315  			h.allocTor.FreeMessage(msg)
   316  		}
   317  		h.noReadyBuffer[msg.GetMsgId()] = readBuf
   318  	case UnmarshalComplete:
   319  		*allMsg = append(*allMsg, ParserMessage{
   320  			Message: msg,
   321  			Header:  buf[h.startOffset],
   322  		})
   323  	}
   324  	return nil
   325  }