github.com/nyan233/littlerpc@v0.4.6-0.20230316182519-0c8d5c48abaf/core/common/msgparser/parser_test.go (about)

     1  package msgparser
     2  
     3  import (
     4  	"encoding/json"
     5  	"errors"
     6  	"github.com/nyan233/littlerpc/core/common/jsonrpc2"
     7  	"github.com/nyan233/littlerpc/core/container"
     8  	message2 "github.com/nyan233/littlerpc/core/protocol/message"
     9  	"github.com/nyan233/littlerpc/core/protocol/message/gen"
    10  	mux2 "github.com/nyan233/littlerpc/core/protocol/message/mux"
    11  	"github.com/nyan233/littlerpc/core/utils/random"
    12  	"github.com/stretchr/testify/assert"
    13  	"github.com/stretchr/testify/suite"
    14  	"math"
    15  	"strconv"
    16  	"strings"
    17  	"sync"
    18  	"testing"
    19  )
    20  
    21  type ParserFullTest struct {
    22  	parser Parser
    23  	suite.Suite
    24  }
    25  
    26  func TestParser(t *testing.T) {
    27  	suite.Run(t, new(ParserFullTest))
    28  }
    29  
    30  func (f *ParserFullTest) SetupTest() {
    31  	f.parser = Get(DefaultParser)(NewDefaultSimpleAllocTor(), 4096).(*lRPCTrait)
    32  }
    33  
    34  func (f *ParserFullTest) TestLRPCParser() {
    35  	t := f.T()
    36  	t.Run("ParseOnData", testParser(f.parser.(*lRPCTrait), func(data []byte) ([]ParserMessage, error) {
    37  		return f.parser.Parse(data)
    38  	}))
    39  	t.Run("ParseOnReader", testParser(f.parser.(*lRPCTrait), func(data []byte) ([]ParserMessage, error) {
    40  		var read bool
    41  		return f.parser.ParseOnReader(func(bytes []byte) (n int, err error) {
    42  			if read {
    43  				return -1, errors.New("already read")
    44  			}
    45  			read = true
    46  			return copy(bytes, data), nil
    47  		})
    48  	}))
    49  }
    50  
    51  func testParser(p *lRPCTrait, parseFunc func(data []byte) ([]ParserMessage, error)) func(t *testing.T) {
    52  	return func(t *testing.T) {
    53  		msg := message2.New()
    54  		msg.SetMsgId(uint64(random.FastRand()))
    55  		msg.SetServiceName("TestParser/LocalTest")
    56  		msg.MetaData.Store("Key", "Value")
    57  		msg.MetaData.Store("Key2", "Value2")
    58  		msg.MetaData.Store("Key3", "Value3")
    59  		msg.AppendPayloads([]byte("hello world"))
    60  		msg.AppendPayloads([]byte("65536"))
    61  		msg.Length()
    62  		var marshalBytes []byte
    63  		err := message2.Marshal(msg, (*container.Slice[byte])(&marshalBytes))
    64  		assert.NoError(t, err)
    65  		muxBlock := mux2.Block{
    66  			Flags:    mux2.Enabled,
    67  			StreamId: random.FastRand(),
    68  			MsgId:    uint64(random.FastRand()),
    69  		}
    70  		muxBlock.SetPayloads(marshalBytes)
    71  		var muxMarshalBytes []byte
    72  		mux2.Marshal(&muxBlock, (*container.Slice[byte])(&muxMarshalBytes))
    73  		marshalBytes = append(marshalBytes, muxMarshalBytes...)
    74  		_, err = parseFunc(marshalBytes[:11])
    75  		assert.NoError(t, err)
    76  		allMasg, err := parseFunc(marshalBytes[11 : msg.Length()+20])
    77  		assert.NoError(t, err)
    78  		assert.Equal(t, len(allMasg), 1)
    79  		allMasg, err = parseFunc(marshalBytes[msg.Length()+20:])
    80  		assert.NoError(t, err)
    81  		assert.Equal(t, len(allMasg), 1)
    82  		assert.Equal(t, len(p.halfBuffer), 0)
    83  		assert.Equal(t, p.startOffset, 0)
    84  		assert.Equal(t, p.endOffset, 0)
    85  		assert.Equal(t, p.clickInterval, 1)
    86  	}
    87  }
    88  
    89  func (f *ParserFullTest) TestConcurrentHalfParse() {
    90  	const (
    91  		ConsumerSize   = 16
    92  		ChanBufferSize = 8
    93  		CycleSize      = 1000
    94  		OnePushMax     = 20
    95  	)
    96  	t := f.T()
    97  	producer := func(channels []chan []byte, data []byte, cycleSize int) {
    98  		for i := 0; i < cycleSize; i++ {
    99  			tmpData := data
   100  			for len(tmpData) > 0 {
   101  				var readN int
   102  				if len(tmpData) >= OnePushMax {
   103  					readN = OnePushMax
   104  				} else {
   105  					readN = len(tmpData)
   106  				}
   107  				for _, channel := range channels {
   108  					channel <- tmpData[:readN]
   109  				}
   110  				tmpData = tmpData[readN:]
   111  			}
   112  		}
   113  		for _, channel := range channels {
   114  			close(channel)
   115  		}
   116  	}
   117  	consumer := func(parser Parser, channel chan []byte, checkHeader byte, wg *sync.WaitGroup) {
   118  		defer wg.Done()
   119  		for {
   120  			select {
   121  			case data, ok := <-channel:
   122  				if !ok {
   123  					return
   124  				}
   125  				msgs, err := parser.Parse(data)
   126  				if err != nil {
   127  					t.Error(err)
   128  				}
   129  				if msgs != nil && len(msgs) > 0 {
   130  					for _, msg := range msgs {
   131  						assert.Equal(t, checkHeader, msg.Header)
   132  						parser.Free(msg.Message)
   133  					}
   134  				}
   135  			}
   136  		}
   137  	}
   138  	consumerChannels := make([]chan []byte, ConsumerSize)
   139  	for k := range consumerChannels {
   140  		consumerChannels[k] = make(chan []byte, ChanBufferSize)
   141  	}
   142  	var wg sync.WaitGroup
   143  	wg.Add(ConsumerSize)
   144  	for _, v := range consumerChannels {
   145  		go consumer(NewLRPCTrait(NewDefaultSimpleAllocTor(), MaxBufferSize), v, message2.MagicNumber, &wg)
   146  	}
   147  	go producer(consumerChannels, gen.NoMuxToBytes(gen.Big), CycleSize)
   148  	wg.Wait()
   149  }
   150  
   151  func (f *ParserFullTest) TestJsonRPC2Parser() {
   152  	t := f.T()
   153  	request := new(jsonrpc2.Request)
   154  	request.Version = jsonrpc2.Version
   155  	request.MessageType = int(message2.Call)
   156  	request.Method = "Test.JsonRPC2Case1"
   157  	request.MetaData = map[string]string{
   158  		"context-id": strconv.FormatInt(int64(random.FastRand()), 10),
   159  		"streamId":   strconv.FormatInt(int64(random.FastRand()), 10),
   160  		"codec":      "json",
   161  		"packer":     "text",
   162  	}
   163  	request.Id = uint64(random.FastRand())
   164  	request.Params = []byte("[1203,\"hello world\",3563]")
   165  	bytes, err := json.Marshal(request)
   166  	assert.NoError(t, err)
   167  	parser := f.parser
   168  	msg, err := parser.Parse(bytes)
   169  	assert.Nil(t, err, err)
   170  	assert.Equal(t, len(msg), 1)
   171  
   172  	iter := msg[0].Message.PayloadsIterator()
   173  	assert.Equal(t, iter.Tail(), 3)
   174  	var i int
   175  	for iter.Next() {
   176  		i++
   177  		switch i {
   178  		case 1:
   179  			assert.Equal(t, string(iter.Take()), "1203")
   180  		case 2:
   181  			assert.Equal(t, string(iter.Take()), "\"hello world\"")
   182  		case 3:
   183  			assert.Equal(t, string(iter.Take()), "3563")
   184  		}
   185  	}
   186  	assert.Equal(t, msg[0].Message.GetServiceName(), "Test.JsonRPC2Case1")
   187  
   188  	// 测试是否能够处理错误的消息类型
   189  	request.MessageType = 0x889839
   190  	bytes, err = json.Marshal(request)
   191  	assert.Nil(t, err, err)
   192  	msg, err = parser.Parse(bytes)
   193  	assert.NotNil(t, err, "input error data but marshal able")
   194  }
   195  
   196  func TestHandler(t *testing.T) {
   197  	for i := uint8(0); true; i++ {
   198  		GetHandler(i)
   199  		if i == math.MaxUint8 {
   200  			break
   201  		}
   202  	}
   203  	defer func() {
   204  		assert.NotNil(t, recover())
   205  	}()
   206  	RegisterHandler(nil)
   207  }
   208  
   209  func parserOnBytes(s string) []byte {
   210  	s = s[1 : len(s)-1]
   211  	sp := strings.Split(s, " ")
   212  	bs := make([]byte, 0, len(sp))
   213  	for _, ss := range sp {
   214  		b, _ := strconv.Atoi(ss)
   215  		bs = append(bs, byte(b))
   216  	}
   217  	return bs
   218  }